Skip to content

Commit 8273e7f

Browse files
authoredOct 18, 2023
GPTQ export w/a (#451)
* Patch the modules in order to export GPTQ models on CPU * Style
1 parent 65b20f5 commit 8273e7f

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed
 

‎optimum/intel/openvino/modeling_decoder.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,34 @@ def _from_transformers(
227227
if use_cache:
228228
task = task + "-with-past"
229229

230+
# Patch the modules to export of GPTQ models w/o GPU
231+
do_gptq_patching = False
232+
config_dict = config.to_dict()
233+
quantization_config = config_dict.get("quantization_config", None)
234+
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
235+
if do_gptq_patching:
236+
torch.set_default_dtype(torch.float32)
237+
orig_cuda_check = torch.cuda.is_available
238+
torch.cuda.is_available = lambda: True
239+
240+
from optimum.gptq import GPTQQuantizer
241+
242+
orig_post_init_model = GPTQQuantizer.post_init_model
243+
244+
def post_init_model(self, model):
245+
from auto_gptq import exllama_set_max_input_length
246+
247+
class StoreAttr(object):
248+
pass
249+
250+
model.quantize_config = StoreAttr()
251+
model.quantize_config.desc_act = self.desc_act
252+
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
253+
model = exllama_set_max_input_length(model, self.max_input_length)
254+
return model
255+
256+
GPTQQuantizer.post_init_model = post_init_model
257+
230258
main_export(
231259
model_name_or_path=model_id,
232260
output=save_dir_path,
@@ -238,10 +266,14 @@ def _from_transformers(
238266
local_files_only=local_files_only,
239267
force_download=force_download,
240268
trust_remote_code=trust_remote_code,
241-
model_kwargs=kwargs,
242269
int8=load_in_8bit,
243270
)
244271

272+
# Unpatch modules after GPTQ export
273+
if do_gptq_patching:
274+
torch.cuda.is_available = orig_cuda_check
275+
GPTQQuantizer.post_init_model = orig_post_init_model
276+
245277
config.is_decoder = True
246278
config.is_encoder_decoder = False
247279
config.save_pretrained(save_dir_path)

0 commit comments

Comments
 (0)
Please sign in to comment.