
Commit 471bee4 ("fix")
Committed: Feb 9, 2024
Parent: 2e59197
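This commit moves the GPTQ export patching from `export_from_model` in `convert.py` into `main_export` in `__main__.py`. Detection no longer wraps `model.config.to_dict()` in a broad try/except on an already-loaded model; instead `main_export` reads `quantization_config` straight off the `AutoConfig` it already fetches, and the CUDA patch is installed before `TasksManager.get_model_from_task` loads the model and removed once the export (including tokenizer export) finishes.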

2 files changed, 34 insertions(+), 39 deletions(-)

optimum/exporters/openvino/__main__.py (+34 lines)
@@ -160,6 +160,7 @@ def main_export(
         )
         convert_tokenizer = False
 
+    do_gptq_patching = False
     custom_architecture = False
     loading_kwargs = {}
     if library_name == "transformers":
@@ -173,6 +174,8 @@ def main_export(
             force_download=force_download,
             trust_remote_code=trust_remote_code,
         )
+        quantization_config = getattr(config, "quantization_config", None)
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
         model_type = config.model_type.replace("_", "-")
 
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
@@ -193,6 +196,32 @@ def main_export(
         if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
             loading_kwargs["attn_implementation"] = "eager"
 
+    # Patch the modules to export of GPTQ models w/o GPU
+    if do_gptq_patching:
+        import torch
+
+        torch.set_default_dtype(torch.float32)
+        orig_cuda_check = torch.cuda.is_available
+        torch.cuda.is_available = lambda: True
+
+        from optimum.gptq import GPTQQuantizer
+
+        orig_post_init_model = GPTQQuantizer.post_init_model
+
+        def post_init_model(self, model):
+            from auto_gptq import exllama_set_max_input_length
+
+            class StoreAttr(object):
+                pass
+
+            model.quantize_config = StoreAttr()
+            model.quantize_config.desc_act = self.desc_act
+            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                model = exllama_set_max_input_length(model, self.max_input_length)
+            return model
+
+        GPTQQuantizer.post_init_model = post_init_model
+
     model = TasksManager.get_model_from_task(
         task,
         model_name_or_path,
@@ -295,3 +324,8 @@ def main_export(
         tokenizer_2 = getattr(model, "tokenizer_2", None)
         if tokenizer_2 is not None:
             export_tokenizer(tokenizer_2, output, suffix="_2")
+
+    # Unpatch modules after GPTQ export
+    if do_gptq_patching:
+        torch.cuda.is_available = orig_cuda_check
+        GPTQQuantizer.post_init_model = orig_post_init_model
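The added hunk makes `torch.cuda.is_available` report True so that auto-gptq's CUDA-only loading paths run on a machine without a GPU, then, at the end of `main_export` (the last hunk above), restores the original function. A minimal sketch of that patch-and-restore pattern, assuming nothing beyond PyTorch itself; `load_gptq_on_cpu` and `load_fn` are illustrative names, not part of optimum:

import torch

def load_gptq_on_cpu(load_fn):
    """Call load_fn() while torch.cuda.is_available() pretends to be True."""
    torch.set_default_dtype(torch.float32)     # as in the commit, pin the default dtype to fp32
    orig_cuda_check = torch.cuda.is_available  # keep a handle to the real check
    torch.cuda.is_available = lambda: True     # satisfy CUDA-only code paths
    try:
        return load_fn()
    finally:
        torch.cuda.is_available = orig_cuda_check  # always undo the global patch

Unlike the commit, which unpatches in straight-line code at the end of `main_export`, the sketch restores in a `finally` block, so an exception while loading cannot leave the global patch in place.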

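The same hunk also swaps out `GPTQQuantizer.post_init_model` for a CPU-safe version that stubs in a `quantize_config` attribute and applies `exllama_set_max_input_length` only when needed, keeping a reference to the original method so it can be put back after export. A generic, self-contained sketch of that method-swap pattern; `Quantizer` is a stand-in class, not the real `GPTQQuantizer`:

class Quantizer:                          # stand-in for GPTQQuantizer
    def post_init_model(self, model):
        raise RuntimeError("needs CUDA")  # pretend the original requires a GPU

orig_post_init_model = Quantizer.post_init_model  # stash the original method

def post_init_model(self, model):         # CPU-safe replacement
    model.initialized_on_cpu = True
    return model

Quantizer.post_init_model = post_init_model
try:
    m = Quantizer().post_init_model(type("Model", (), {})())
    assert m.initialized_on_cpu           # the replacement ran, not the original
finally:
    Quantizer.post_init_model = orig_post_init_model  # restore, as the commit does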
optimum/exporters/openvino/convert.py (-39 lines)
@@ -588,40 +588,6 @@ def export_from_model(
     else:
         model_type = model.config.model_type.replace("_", "-")
 
-    # Patch the modules to export of GPTQ models w/o GPU
-    do_gptq_patching = False
-    try:
-        config_dict = model.config.to_dict()
-        quantization_config = config_dict.get("quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
-    except Exception:
-        pass
-
-    if do_gptq_patching:
-        import torch
-
-        torch.set_default_dtype(torch.float32)
-        orig_cuda_check = torch.cuda.is_available
-        torch.cuda.is_available = lambda: True
-
-        from optimum.gptq import GPTQQuantizer
-
-        orig_post_init_model = GPTQQuantizer.post_init_model
-
-        def post_init_model(self, model):
-            from auto_gptq import exllama_set_max_input_length
-
-            class StoreAttr(object):
-                pass
-
-            model.quantize_config = StoreAttr()
-            model.quantize_config.desc_act = self.desc_act
-            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                model = exllama_set_max_input_length(model, self.max_input_length)
-            return model
-
-        GPTQQuantizer.post_init_model = post_init_model
-
     custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE
 
     if task is not None:
@@ -756,11 +722,6 @@ class StoreAttr(object):
         model_kwargs=model_kwargs,
     )
 
-    # Unpatch modules after GPTQ export
-    if do_gptq_patching:
-        torch.cuda.is_available = orig_cuda_check
-        GPTQQuantizer.post_init_model = orig_post_init_model
-
 
 def export_tokenizer(
     tokenizer,

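With these hunks gone, `export_from_model` no longer does its own GPTQ detection; the equivalent check now happens once in `main_export`. A sketch of that new detection logic, mirroring the commit's dict-style access to the quantization metadata; the helper name `is_gptq_model` is hypothetical:

from transformers import AutoConfig

def is_gptq_model(model_name_or_path: str) -> bool:
    """True if the checkpoint's config carries GPTQ quantization metadata."""
    config = AutoConfig.from_pretrained(model_name_or_path)
    quantization_config = getattr(config, "quantization_config", None)
    return bool(quantization_config) and quantization_config["quant_method"] == "gptq"

An export run that exercises this path would look something like `optimum-cli export openvino --model <gptq-model-id> <output_dir>`, where `<gptq-model-id>` is any checkpoint whose `config.json` contains such a `quantization_config` entry.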