Skip to content

Commit 53240c3

Browse files
Allow GPTQModel to auto select Marlin or faster kernels for inference only ops (#2138)
* select quant_linear with pack
* up GPTQMODEL_MINIMUM_VERSION
* Update quantizer.py
* update gptqmodel version

---------

Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
1 parent 72498dd commit 53240c3

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

optimum/gptq/quantizer.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,7 @@ def __init__(
220220
)
221221
self.exllama_version = self.exllama_config["version"]
222222

223-
def select_quant_linear(self, device_map: Union[str, dict]):
223+
def select_quant_linear(self, device_map: Union[str, dict], pack: bool = False):
224224
if is_gptqmodel_available():
225225
self.quant_linear = hf_select_quant_linear(
226226
bits=self.bits,
@@ -231,6 +231,7 @@ def select_quant_linear(self, device_map: Union[str, dict]):
231231
meta=self.meta,
232232
device_map=device_map,
233233
backend=self.backend,
234+
pack=pack,
234235
)
235236
else:
236237
self.quant_linear = hf_select_quant_linear(
@@ -301,7 +302,7 @@ def convert_model(self, model: nn.Module, **kwargs):
301302
)
302303
del layers_to_be_replaced[name]
303304

304-
self.select_quant_linear(device_map=kwargs.get("device_map", None))
305+
self.select_quant_linear(device_map=kwargs.get("device_map", None), pack=False)
305306

306307
self._replace_by_quant_layers(model, layers_to_be_replaced)
307308

@@ -761,7 +762,7 @@ def pack_model(
761762
layers = get_layers(model)
762763
layers = {n: layers[n] for n in quantizers}
763764

764-
self.select_quant_linear(device_map=model.hf_device_map)
765+
self.select_quant_linear(device_map=model.hf_device_map, pack=True)
765766

766767
self._replace_by_quant_layers(model, quantizers)
767768
qlayers = get_layers(model, [self.quant_linear])

optimum/utils/import_utils.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
5252
TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
5353
DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
5454
AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
55-
GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2")
55+
GPTQMODEL_MINIMUM_VERSION = version.parse("1.6.0")
5656

5757

5858
# This is the minimal required version to support some ONNX Runtime features

0 commit comments

Comments
 (0)