@@ -227,6 +227,34 @@ def _from_transformers(
         if use_cache:
             task = task + "-with-past"

+        # Patch the modules to export GPTQ models w/o GPU
+        do_gptq_patching = False
+        config_dict = config.to_dict()
+        quantization_config = config_dict.get("quantization_config", None)
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+        if do_gptq_patching:
+            torch.set_default_dtype(torch.float32)
+            orig_cuda_check = torch.cuda.is_available
+            torch.cuda.is_available = lambda: True
+
+            from optimum.gptq import GPTQQuantizer
+
+            orig_post_init_model = GPTQQuantizer.post_init_model
+
+            def post_init_model(self, model):
+                from auto_gptq import exllama_set_max_input_length
+
+                class StoreAttr(object):
+                    pass
+
+                model.quantize_config = StoreAttr()
+                model.quantize_config.desc_act = self.desc_act
+                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                    model = exllama_set_max_input_length(model, self.max_input_length)
+                return model
+
+            GPTQQuantizer.post_init_model = post_init_model
+
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
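The hunk above spoofs `torch.cuda.is_available` so that auto-gptq's CUDA checks pass during a CPU-only export, and swaps in a simplified `GPTQQuantizer.post_init_model`; both originals are restored in the hunk below. A minimal standalone sketch of that temporary-override pattern (not part of the PR, shown only for illustration):

```python
import contextlib

import torch


@contextlib.contextmanager
def fake_cuda_available():
    """Temporarily pretend a CUDA device is present, then restore the real check."""
    orig_cuda_check = torch.cuda.is_available  # keep a reference for restoration
    torch.cuda.is_available = lambda: True  # same trick as the patch above
    try:
        yield
    finally:
        torch.cuda.is_available = orig_cuda_check  # always undo the override


with fake_cuda_available():
    assert torch.cuda.is_available()  # spoofed inside the block
```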
@@ -238,10 +266,14 @@ def _from_transformers(
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            model_kwargs=kwargs,
             int8=load_in_8bit,
         )

+        # Unpatch modules after GPTQ export
+        if do_gptq_patching:
+            torch.cuda.is_available = orig_cuda_check
+            GPTQQuantizer.post_init_model = orig_post_init_model
+
         config.is_decoder = True
         config.is_encoder_decoder = False
         config.save_pretrained(save_dir_path)
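With the patch applied before `main_export` and reverted right after it, the export path can run on a machine without a GPU. A hedged usage sketch, assuming optimum-intel's public `OVModelForCausalLM` API; the model id is only a placeholder, not taken from the PR:

```python
from optimum.intel import OVModelForCausalLM

# Placeholder id: any GPTQ-quantized causal-LM checkpoint on the Hub.
model_id = "some-org/some-model-GPTQ"

# export=True routes through _from_transformers -> main_export, where the
# GPTQ patching above is applied and then reverted.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
model.save_pretrained("openvino_gptq_model")
```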
@@ -320,7 +352,10 @@ def forward(
             input_ids = input_ids[:, -1:]

         inputs = {}
+        past_len = 0
         if past_key_values is not None:
+            seq_len_dim = 1 if self.model.input(self.key_value_input_names[0]).get_partial_shape()[1].is_dynamic else 2
+            past_len = past_key_values[0][0].shape[seq_len_dim]
             if self._pkv_precision == Type.bf16:
                 # numpy does not support bf16, pretending f16, should change to bf16
                 past_key_values = tuple(
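The new `past_len` bookkeeping reads the sequence length of the cached keys/values, choosing dimension 1 or 2 depending on which axis of the first key/value input has a dynamic shape. An illustrative sketch under the assumption of the common `(batch, num_heads, past_len, head_dim)` cache layout; the shapes below are made up:

```python
import numpy as np

# Made-up cache tensor in the (batch, num_heads, past_len, head_dim) layout.
batch, num_heads, past_len, head_dim = 1, 12, 7, 64
past_key = np.zeros((batch, num_heads, past_len, head_dim), dtype=np.float32)

# For this layout the sequence axis is 2; some exported models keep the cache as
# (batch, past_len, ...), in which case the axis is 1 -- that is what the
# partial-shape check in the hunk above distinguishes.
seq_len_dim = 2
assert past_key.shape[seq_len_dim] == past_len  # recovered length of the cache
```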
@@ -355,8 +390,13 @@ def forward(
         inputs["input_ids"] = np.array(input_ids)

         # Add the attention_mask inputs when needed
-        if "attention_mask" in self.input_names and attention_mask is not None:
-            inputs["attention_mask"] = np.array(attention_mask)
+        if "attention_mask" in self.input_names:
+            if attention_mask is not None:
+                inputs["attention_mask"] = np.array(attention_mask)
+            else:
+                inputs["attention_mask"] = np.ones(
+                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
+                )

         # Run inference
         self.request.start_async(inputs, shared_memory=True)
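When the caller provides no `attention_mask`, the new branch fabricates an all-ones mask spanning both the cached tokens and the newly supplied ids. A small sketch of that fallback with made-up shapes:

```python
import numpy as np

input_ids = np.array([[42]])  # a single new token during generation
past_len = 7  # number of tokens already stored in the KV cache

# All-ones mask covering both the cached tokens and the new input tokens.
attention_mask = np.ones(
    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=input_ids.dtype
)
assert attention_mask.shape == (1, 8)
```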