
Commit 21de42f

Authored by jiqing-feng, LRL-ModelCloud, ZX-ModelCloud, Qubitium and IlyasMoutawwakil on Dec 19, 2024
Enable GPTQModel (#2064)
* align gptq check to transformers for supporting cpu
* fix comment
* gptqmodel
* compatible with auto-gptq
* fix compatible with auto-gptq
* fix compatible with auto-gptq linear
* revert unrelated changes
* gptqmodel need use checkpoint_format (#1)
  * need checkpoint_format
  * default value of checkpoint_format is gptq
  * fix quantize
  * fix quantize
  * fix quantize
  * Update quantizer.py
  * need convert to v1 before gptqmodel save
  * back checkpoint_format to gptq after convert
  * cleanup code
  * sym=False is not supported with auto-gptq
  * add comments
  * cleanup code
  * Update quantizer.py
  * always convert v2 to v1 if checkpoint_format = "gptq"
  * Update quantizer.py
  Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
  Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
* Mod backend code (#2)
  * keep gptq_v2 if sym is false
  * use hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format, and hf_gptqmodel_post_init
  * no need check backend
  * use device_map
  * cleanup
  * Update quantizer.py
  * move import
  Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
* fix format and log
* fix version check
* enable gptqmodel tests
* update check quant type
* Fix optimum compat (#3)
  * add meta info
  * cleanup
  * cleanup
  * The value of quantizer should be an array
  * Update quantizer.py
  * If is_auto_gptq_available() also writes "auto_gptq:version" to "quantizer"
  * If is_auto_gptq_available() also writes "auto_gptq:version" to "quantizer"
  * Update quantizer.py
  * cleanup
  * comment on meta
  * hf_select_quant_linear pass checkpoint_format
  * add todo fix
  * move convert code to quantizer.save()
  * Update quantizer.py
  * Optimize hf_convert_gptq_v2_to_v1_format()
  * Optimize hf_convert_gptq_v1_to_v2_format()
  * fix GPTQTestCUDA
  * hf_select_quant_linear() always set pack=True
  * gptqmodel.hf_select_quant_linear() now does not select ExllamaV2
  * gptqmodel.hf_select_quant_linear() now does not select ExllamaV2
  * GPTQQuantizer add backend
  * lower checkpoint_format and backend
  * cleanup
  * move backend to bottom
  * no need to check gptqmodel version for ipex support
  * Update import_utils.py
  * Update quantizer.py
  * fix UnboundLocalError: cannot access local variable 'version' where it is not associated with a value
  * make version var short
  * Update import_utils.py
  * fix unittest
  * use assertLessEqual
  Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
  Co-authored-by: LRL <lrl@lbx.dev>
* fix format and convert v2 to v1
* [Fix] all tensors not same device (#5)
  * fix device error
  * update gptqmodel version
  * fix test
* fix format
* add gptqmodel tests which contains cpu
* fix all auto-gptq tests
* revert tests
* rm gptqmodel yaml
* fix comment
* enable real cpu tests by fp32
* fix test model name
* keep the original device setting when using auto-gptq
* Update optimum/gptq/quantizer.py (Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>)
* Update optimum/gptq/quantizer.py (Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>)

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Co-authored-by: LRL <lrl@lbx.dev>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
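For context, a minimal sketch of the end-to-end flow this commit enables: quantizing with optimum's GPTQQuantizer on a machine where gptqmodel (or auto-gptq) is installed. The model id, calibration dataset, and save directory below are illustrative placeholders, not values taken from this commit.

# Sketch only: assumes `pip install optimum gptqmodel` (auto-gptq also works, sym=True only).
# "facebook/opt-125m", the "c4" calibration set and the save path are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.gptq import GPTQQuantizer

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# With gptqmodel installed, quantization may also run on CPU or Intel XPU;
# sym=False or checkpoint_format="gptq_v2" additionally require gptqmodel.
quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128, sym=True)
quantized_model = quantizer.quantize_model(model, tokenizer)
quantizer.save(quantized_model, "opt-125m-gptq")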
1 parent 0ea269f commit 21de42f

File tree: 4 files changed (+227 −61 lines)
 

‎optimum/gptq/quantizer.py

+195 −58

@@ -12,24 +12,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import json
 import os
 from enum import Enum
 from logging import getLogger
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+from packaging import version
 from torch import nn
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer
 from transformers.pytorch_utils import Conv1D
 from transformers.utils.quantization_config import QuantizationMethod
 
-from ..utils import is_accelerate_available, is_auto_gptq_available
+from ..utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available
 from ..utils.modeling_utils import recurse_getattr
+from ..version import __version__ as optimum_version
 from .constants import GPTQ_CONFIG
 from .data import get_dataset, prepare_dataset
-from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen
+from .utils import (
+    get_block_name_with_pattern,
+    get_device,
+    get_layers,
+    get_preceding_modules,
+    get_seqlen,
+    nested_move_to,
+)
 
 
 if is_accelerate_available():
@@ -40,14 +50,27 @@
     from accelerate.hooks import remove_hook_from_module
 
 if is_auto_gptq_available():
+    from auto_gptq import __version__ as autogptq_version
     from auto_gptq import exllama_set_max_input_length
-    from auto_gptq.modeling._utils import autogptq_post_init
+    from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init
     from auto_gptq.quantization import GPTQ
-    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear
+
+if is_gptqmodel_available():
+    from gptqmodel import exllama_set_max_input_length
+    from gptqmodel.quantization import GPTQ
+    from gptqmodel.utils.importer import hf_select_quant_linear
+    from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
+    from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
+    from gptqmodel.version import __version__ as gptqmodel_version
 
 logger = getLogger(__name__)
 
 
+def has_device_more_than_cpu():
+    return torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())
+
+
 class ExllamaVersion(int, Enum):
     ONE = 1
     TWO = 2
@@ -74,10 +97,13 @@ def __init__(
         batch_size: int = 1,
         pad_token_id: Optional[int] = None,
         disable_exllama: bool = False,
-        exllama_config: Dict[str, Any] = None,
+        exllama_config: Optional[Dict[str, Any]] = None,
         max_input_length: Optional[int] = None,
         cache_block_outputs: Optional[bool] = True,
         modules_in_block_to_quantize: Optional[List[List[str]]] = None,
+        checkpoint_format: str = "gptq",
+        meta: Optional[Dict[str, any]] = None,
+        backend: Optional[str] = None,
         *args,
         **kwargs,
     ):
@@ -129,6 +155,13 @@ def __init__(
                 List list of module names to quantize in the block specified. This argument is useful to exclude certain linear modules from being quantized.
                 The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially.
                 If not set, we will quantize all linear layers. Example: `inside_layer_modules=[["self_attention.query_key_value"], ["mlp.dense_h_to_4h"]]`
+            checkpoint_format (`str`, *optional*, defaults to `gptq`):
+                GPTQ weight format. `gptq`(v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only.
+            meta (`Dict[str, any]`, *optional*):
+                Properties, such as tooling:version, that do not directly contributes to quantization or quant inference are stored in meta.
+                i.e. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"]
+            backend (`str`, *optional*):
+                Controls which gptq kernel to be used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, only valid value is None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py
         """
 
         self.bits = bits
@@ -138,6 +171,9 @@ def __init__(
         self.desc_act = desc_act
         self.sym = sym
         self.true_sequential = true_sequential
+        self.checkpoint_format = checkpoint_format.lower()
+        self.meta = meta
+        self.backend = backend.lower() if backend is not None else None
         self.use_cuda_fp16 = use_cuda_fp16
         self.model_seqlen = model_seqlen
         self.block_name_to_quantize = block_name_to_quantize
@@ -161,6 +197,8 @@ def __init__(
             "true_sequential",
             "quant_method",
             "modules_in_block_to_quantize",
+            "checkpoint_format",
+            "meta",
         ]
 
         if self.bits not in [2, 3, 4, 8]:
@@ -182,13 +220,49 @@ def __init__(
                 )
         self.exllama_version = self.exllama_config["version"]
 
+    def select_quant_linear(self, device_map: Union[str, dict]):
+        if is_gptqmodel_available():
+            self.quant_linear = hf_select_quant_linear(
+                bits=self.bits,
+                group_size=self.group_size,
+                desc_act=self.desc_act,
+                sym=self.sym,
+                checkpoint_format=self.checkpoint_format,
+                meta=self.meta,
+                device_map=device_map,
+                backend=self.backend,
+            )
+        else:
+            self.quant_linear = hf_select_quant_linear(
+                use_triton=False,
+                desc_act=self.desc_act,
+                group_size=self.group_size,
+                bits=self.bits,
+                disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
+                disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
+            )
+
     def to_dict(self):
         """
         Returns the args in dict format.
         """
         gptq_dict = {}
         for key in self.serialization_keys:
             gptq_dict[key] = getattr(self, key)
+
+        if gptq_dict.get("meta") is None:
+            gptq_dict["meta"] = {}
+
+        meta = gptq_dict["meta"]
+        # store both optimum:version and gptq_lib:version into quantize_config.meta.quantizer
+        if meta.get("quantizer") is None:
+            meta["quantizer"] = [f"optimum:{optimum_version}"]
+
+        if is_gptqmodel_available():
+            meta["quantizer"].append(f"gptqmodel:{gptqmodel_version}")
+        elif is_auto_gptq_available():
+            meta["quantizer"].append(f"auto_gptq:{autogptq_version}")
+
         return gptq_dict
 
     @classmethod
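To illustrate the two serialization keys added above, here is a hypothetical, abridged dict of the shape to_dict() would produce; the version strings in meta.quantizer are placeholders.

# Hypothetical, abridged GPTQ config as serialized by to_dict(); versions are placeholders.
gptq_config = {
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "desc_act": False,
    "quant_method": "gptq",
    "checkpoint_format": "gptq",  # new key: "gptq" (v1) or "gptq_v2"
    "meta": {"quantizer": ["optimum:1.24.0", "gptqmodel:1.4.2"]},  # new key
}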
@@ -205,7 +279,7 @@ def from_dict(cls, config_dict: Dict[str, Any]):
         """
         return cls(**config_dict)
 
-    def convert_model(self, model: nn.Module):
+    def convert_model(self, model: nn.Module, **kwargs):
         """
         Convert the model to a GPTQ model by getting and replacing the layers.
 
@@ -226,7 +300,11 @@ def convert_model(self, model: nn.Module):
                         f"Quantization disabled for {name} (only modules_in_block_to_quantize={self.modules_in_block_to_quantize} are quantized)"
                     )
                     del layers_to_be_replaced[name]
+
+        self.select_quant_linear(device_map=kwargs.get("device_map", None))
+
         self._replace_by_quant_layers(model, layers_to_be_replaced)
+
         return model
 
     def get_no_split_module_classes(self, model):
@@ -253,15 +331,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
             name (`str`, defaults to `""`):
                 To keep track of the name of the current module
         """
-        QuantLinear = dynamically_import_QuantLinear(
-            use_triton=False,
-            desc_act=self.desc_act,
-            group_size=self.group_size,
-            bits=self.bits,
-            disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
-            disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
-        )
-        if isinstance(module, QuantLinear):
+        if isinstance(module, self.quant_linear):
             return
         for attr in dir(module):
             layer = getattr(module, attr)
@@ -279,20 +349,37 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
                     in_features = layer.weight.shape[0]
                     out_features = layer.weight.shape[1]
                 bias = layer.bias is not None
-                if not (self.desc_act) or self.group_size == -1:
-                    new_layer = QuantLinear(
+                if is_gptqmodel_available():
+                    new_layer = self.quant_linear(
                         self.bits,
                         self.group_size,
+                        self.desc_act,
+                        self.sym,
                         in_features,
                         out_features,
                         bias,
-                        use_cuda_fp16=self.use_cuda_fp16,
                         weight_dtype=layer.weight.dtype,
                     )
                 else:
-                    new_layer = QuantLinear(
-                        self.bits, self.group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
-                    )
+                    if not (self.desc_act) or self.group_size == -1:
+                        new_layer = self.quant_linear(
+                            self.bits,
+                            self.group_size,
+                            in_features,
+                            out_features,
+                            bias,
+                            use_cuda_fp16=self.use_cuda_fp16,
+                            weight_dtype=layer.weight.dtype,
+                        )
+                    else:
+                        new_layer = self.quant_linear(
+                            self.bits,
+                            self.group_size,
+                            in_features,
+                            out_features,
+                            bias,
+                            weight_dtype=layer.weight.dtype,
+                        )
                 new_layer.device = device
                 setattr(module, attr, new_layer.to(device))
         for name1, child in module.named_children():
@@ -318,13 +405,41 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
             `nn.Module`: The quantized model
         """
 
-        if not is_auto_gptq_available():
-            raise RuntimeError("auto-gptq is required in order to perform quantzation : `pip install auto-gptq`")
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed to quantize model.")
+        if not is_auto_gptq_available() and not is_gptqmodel_available():
+            raise RuntimeError(
+                "gptqmodel or auto-gptq is required in order to perform gptq quantzation: `pip install gptqmodel` or `pip install auto-gptq`. Please notice that auto-gptq will be deprecated in the future."
+            )
+        elif is_gptqmodel_available() and is_auto_gptq_available():
+            logger.warning(
+                "Detected gptqmodel and auto-gptq, will use gptqmodel. The auto_gptq will be deprecated in the future."
+            )
+
+        gptq_supports_cpu = (
+            is_auto_gptq_available()
+            and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+        ) or is_gptqmodel_available()
+
+        if not gptq_supports_cpu and not torch.cuda.is_available():
+            raise RuntimeError(
+                "No cuda gpu or cpu support using Intel/IPEX found. A gpu or cpu with Intel/IPEX is required for quantization."
+            )
+
+        if not self.sym and not is_gptqmodel_available():
+            raise ValueError(
+                "Asymmetric sym=False quantization is not supported with auto-gptq. Please use gptqmodel: `pip install gptqmodel`"
+            )
+
+        if self.checkpoint_format == "gptq_v2" and not is_gptqmodel_available():
+            raise ValueError(
+                "gptq_v2 format only supported with gptqmodel. Please install gptqmodel: `pip install gptqmodel`"
+            )
 
         model.eval()
 
+        # gptqmodel internal is gptq_v2 for asym support, gptq(v1) can only support sym=True
+        if is_gptqmodel_available() and self.checkpoint_format != "gptq_v2":
+            self.checkpoint_format = "gptq_v2"
+
         # For Transformer model
         has_config = False
         has_device_map = False
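The checks above gate asymmetric quantization and the v2 checkpoint format on gptqmodel. A hedged sketch of a configuration that passes only when gptqmodel is installed (bit-width and dataset are again placeholders):

# Rejected with a ValueError when only auto-gptq is available; accepted with gptqmodel.
quantizer = GPTQQuantizer(
    bits=4,
    dataset="c4",
    sym=False,                    # asymmetric quantization
    checkpoint_format="gptq_v2",  # gptqmodel-only on-disk format
)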
@@ -403,39 +518,40 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
 
         blocks = recurse_getattr(model, self.block_name_to_quantize)
 
+        cur_layer_device = get_device(blocks[0])
+        if not is_gptqmodel_available():
+            cur_layer_device = 0
+
         if not has_device_map:
-            # put modules from module_name_preceding_first_block on cuda
+            # put modules from module_name_preceding_first_block on cuda or xpu or cpu
+            to_device = cur_layer_device
             for module_name in self.module_name_preceding_first_block:
                 module = recurse_getattr(model, module_name)
                 if module is None:
                     raise ValueError(f"Module {module_name} was not found in model")
-                module = module.to(0)
-            blocks[0] = blocks[0].to(0)
+                module = module.to(to_device)
+            blocks[0] = blocks[0].to(to_device)
 
         def store_input_hook(_, input, *args):
             kwargs = args[0]
             if input is None:
                 if "hidden_states" in kwargs:
-                    input = (kwargs["hidden_states"],)
+                    input = (nested_move_to(kwargs["hidden_states"], cur_layer_device),)
                 else:
                     raise ValueError("No input value found in the foward pass")
             layer_inputs.append(input)
             other_kwargs = {}
             for k, v in kwargs.items():  # make sure other arguments also be captured
                 if k not in ["hidden_states"]:
-                    other_kwargs[k] = v
+                    other_kwargs[k] = nested_move_to(v, cur_layer_device)
             layer_input_kwargs.append(other_kwargs)
             raise ValueError
 
         if self.cache_block_outputs:
             handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
             for data in dataset:
                 for k, v in data.items():
-                    # put the data on gpu, we won't put them back to cpu
-                    if not has_device_map or device.type == "cpu":
-                        data[k] = v.to(0)
-                    else:
-                        data[k] = v.to(device)
+                    data[k] = nested_move_to(v, cur_layer_device)
                 try:
                     model(**data)
                 except ValueError:
@@ -450,6 +566,8 @@ def store_input_hook(_, input, *args):
                 raise ValueError(f"Module {module_name} was not found in model")
 
         torch.cuda.empty_cache()
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            torch.xpu.empty_cache()
 
         # Step 3: Quantize the blocks
         quantizers = {}
@@ -460,11 +578,7 @@ def store_input_hook(_, input, *args):
                 handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
                 for data in dataset:
                     for k, v in data.items():
-                        # put the data on gpu, we won't put them back to cpu
-                        if not has_device_map or device.type == "cpu":
-                            data[k] = v.to(0)
-                        else:
-                            data[k] = v.to(device)
+                        data[k] = nested_move_to(v, cur_layer_device)
                     try:
                         model(**data)
                     except ValueError:
@@ -473,9 +587,12 @@ def store_input_hook(_, input, *args):
 
             # move block to cuda if needed
             # in case we have offload modules, we need to put them on cuda because of GPTQ object
-            if not has_device_map or get_device(block) == torch.device("cpu"):
+            if (not has_device_map or get_device(block) == torch.device("cpu")) and has_device_more_than_cpu():
                block = block.to(0)
             layers = get_layers(block)
+            block_device = get_device(block)
+            if not is_gptqmodel_available():
+                block_device = 0
             if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
                 if self.true_sequential:
                     layers_name_list = self.modules_in_block_to_quantize
@@ -509,15 +626,20 @@ def tmp(_, input, output):
                 for j in range(len(dataset)):
                     # the args are already on the gpu
                     # don't need to store the output
+                    layer_inputs[j] = nested_move_to(layer_inputs[j], block_device)
+                    for k, v in layer_input_kwargs[j].items():
+                        layer_input_kwargs[j][k] = nested_move_to(v, block_device)
+
                     block(*layer_inputs[j], **layer_input_kwargs[j])
                 # remove hook
                 for h in handles:
                     h.remove()
                 for name in subset_name_list:
                     logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...")
-                    scale, zero, g_idx = gptq[name].fasterquant(
+                    quant_outputs = gptq[name].fasterquant(
                         percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act
                     )
+                    scale, zero, g_idx = quant_outputs[0], quant_outputs[1], quant_outputs[2]
                     quantizers[f"{self.block_name_to_quantize}.{i}.{name}"] = (
                         gptq[name].quantizer,
                         scale,
@@ -543,11 +665,13 @@ def tmp(_, input, output):
             del layer_inputs
             layer_inputs = []
             torch.cuda.empty_cache()
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                torch.xpu.empty_cache()
 
         if self.bits == 4:
             # device not on gpu
             if device.type != "cuda" or (has_device_map and any(d in devices for d in ["cpu", "disk", "hpu"])):
-                if not self.disable_exllama:
+                if not self.disable_exllama and not is_gptqmodel_available():
                     logger.warning(
                         "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
                     )
@@ -578,6 +702,8 @@ def tmp(_, input, output):
         model = self.post_init_model(model)
 
         torch.cuda.empty_cache()
+        if hasattr(torch, "xpu") and torch.xpu.is_available():
+            torch.xpu.empty_cache()
         return model
 
     def post_init_model(self, model):
@@ -601,9 +727,14 @@ def post_init_model(self, model):
         class StoreAttr(object):
             pass
 
+        if is_gptqmodel_available():
+            model, _ = hf_convert_gptq_v1_to_v2_format(
+                model, self.bits, self.quant_linear, self.checkpoint_format, self.meta
+            )
+
         model.quantize_config = StoreAttr()
         model.quantize_config.desc_act = self.desc_act
-        model = autogptq_post_init(model, use_act_order=self.desc_act)
+        model = gptq_post_init(model, use_act_order=self.desc_act)
         if (
             self.desc_act
             and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE)
@@ -626,19 +757,14 @@ def pack_model(
             quantizers (`Dict[str,Tuple]`):
                 A mapping of the layer name and the data needed to pack the layer
         """
-        QuantLinear = dynamically_import_QuantLinear(
-            use_triton=False,
-            desc_act=self.desc_act,
-            group_size=self.group_size,
-            bits=self.bits,
-            disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
-            disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
-        )
         logger.info("Packing model...")
         layers = get_layers(model)
         layers = {n: layers[n] for n in quantizers}
+
+        self.select_quant_linear(device_map=model.hf_device_map)
+
         self._replace_by_quant_layers(model, quantizers)
-        qlayers = get_layers(model, [QuantLinear])
+        qlayers = get_layers(model, [self.quant_linear])
         for name in qlayers:
             logger.info(name)
             quantizers[name], scale, zero, g_idx = quantizers[name]
@@ -673,6 +799,15 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa
                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
 
         """
+
+        # convert gptqmodel internal gptq_v2 format to v1 for max compatibility
+        if is_gptqmodel_available():
+            model, converted = hf_convert_gptq_v2_to_v1_format(
+                model, self.sym, self.bits, self.quant_linear, self.checkpoint_format, self.meta
+            )
+            if converted:
+                self.checkpoint_format = "gptq"
+
         os.makedirs(save_dir, exist_ok=True)
         model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
         with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f:
@@ -736,10 +871,12 @@ def load_quantized_model(
     Returns:
         `nn.Module`: The quantized model
     """
-    if not torch.cuda.is_available():
-        raise RuntimeError("No GPU found. A GPU is needed to run quantized model.")
-    if not is_auto_gptq_available():
-        raise RuntimeError("auto-gptq is required in order to load quantized weights : `pip install auto-gptq`")
+    if not torch.cuda.is_available() and not is_gptqmodel_available():
+        raise RuntimeError("No GPU found. A GPU is needed to run quantized model by auto_gptq.")
+    if not is_auto_gptq_available() and not is_gptqmodel_available():
+        raise RuntimeError(
+            "gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`) is required in order to load quantized weights. Please notice that auto-gptq will be deprecated in the future."
+        )
     if not is_accelerate_available():
         raise RuntimeError(
             "You need to install accelerate in order to load and dispatch weights to"
@@ -777,7 +914,7 @@ def load_quantized_model(
     quantizer.exllama_version = quantizer.exllama_config["version"]
     quantizer.max_input_length = max_input_length
 
-    model = quantizer.convert_model(model)
+    model = quantizer.convert_model(model, device_map=device_map)
 
     if no_split_module_classes is None:
         no_split_module_classes = quantizer.get_no_split_module_classes(model)
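As a usage note for the relaxed checks in load_quantized_model above, a sketch of reloading a saved checkpoint without a GPU once gptqmodel is installed; the config id and save folder are placeholders tied to the earlier sketch.

# Sketch: reload a quantized checkpoint on CPU, which the gptqmodel path now permits.
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from optimum.gptq import load_quantized_model

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained("facebook/opt-125m"), torch_dtype=torch.float16
    )
empty_model.tie_weights()
model = load_quantized_model(empty_model, save_folder="opt-125m-gptq", device_map={"": "cpu"})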

‎optimum/gptq/utils.py

+15 −0

@@ -113,3 +113,18 @@ def get_seqlen(model: nn.Module):
         "We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`"
     )
     return 2048
+
+
+def move_to(obj: torch.Tensor, device: torch.device):
+    if get_device(obj) != device:
+        obj = obj.to(device)
+    return obj
+
+
+def nested_move_to(v, device):
+    if isinstance(v, torch.Tensor):
+        return move_to(v, device)
+    elif isinstance(v, (list, tuple)):
+        return type(v)([nested_move_to(e, device) for e in v])
+    else:
+        return v
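A short usage sketch of the helpers added above; the tensors are arbitrary examples.

# nested_move_to() relocates tensors (and lists/tuples of tensors) to a target device
# and passes any other value through unchanged.
import torch

from optimum.gptq.utils import nested_move_to

batch = (torch.ones(2, 2), [torch.zeros(3)], {"labels": None})
moved = nested_move_to(batch, torch.device("cpu"))  # tensors moved, the dict is returned as-is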

‎optimum/utils/__init__.py

+1 −0

@@ -37,6 +37,7 @@
     is_auto_gptq_available,
     is_datasets_available,
     is_diffusers_available,
+    is_gptqmodel_available,
     is_onnx_available,
     is_onnxruntime_available,
     is_pydantic_available,

‎optimum/utils/import_utils.py

+16 −3

@@ -52,6 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
 DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
 AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99")  # Allows 0.5.0.dev0
+GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2")
 
 
 # This is the minimal required version to support some ONNX Runtime features
@@ -67,6 +68,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _accelerate_available = _is_package_available("accelerate")
 _diffusers_available = _is_package_available("diffusers")
 _auto_gptq_available = _is_package_available("auto_gptq")
+_gptqmodel_available = _is_package_available("gptqmodel")
 _timm_available = _is_package_available("timm")
 _sentence_transformers_available = _is_package_available("sentence_transformers")
 _datasets_available = _is_package_available("datasets")
@@ -138,12 +140,23 @@ def is_datasets_available():
 
 def is_auto_gptq_available():
     if _auto_gptq_available:
-        version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
-        if AUTOGPTQ_MINIMUM_VERSION < version_autogptq:
+        v = version.parse(importlib_metadata.version("auto_gptq"))
+        if v >= AUTOGPTQ_MINIMUM_VERSION:
             return True
         else:
             raise ImportError(
-                f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported"
+                f"Found an incompatible version of auto-gptq. Found version {v}, but only version >= {AUTOGPTQ_MINIMUM_VERSION} are supported"
+            )
+
+
+def is_gptqmodel_available():
+    if _gptqmodel_available:
+        v = version.parse(importlib_metadata.version("gptqmodel"))
+        if v >= GPTQMODEL_MINIMUM_VERSION:
+            return True
+        else:
+            raise ImportError(
+                f"Found an incompatible version of gptqmodel. Found version {v}, but only version >= {GPTQMODEL_MINIMUM_VERSION} are supported"
             )
 
 
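For completeness, a sketch of how a caller can branch on the new availability check exported from optimum.utils (the error text is illustrative):

from optimum.utils import is_auto_gptq_available, is_gptqmodel_available

if is_gptqmodel_available():
    gptq_backend = "gptqmodel"   # preferred going forward
elif is_auto_gptq_available():
    gptq_backend = "auto-gptq"   # still supported, slated for deprecation
else:
    raise RuntimeError("Install gptqmodel (`pip install gptqmodel`) or auto-gptq to use GPTQ.")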