@@ -227,6 +227,34 @@ def _from_transformers(
         if use_cache:
             task = task + "-with-past"
 
+        # Patch the modules to enable export of GPTQ models without a GPU
+        do_gptq_patching = False
+        config_dict = config.to_dict()
+        quantization_config = config_dict.get("quantization_config", None)
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+        if do_gptq_patching:
+            torch.set_default_dtype(torch.float32)
+            orig_cuda_check = torch.cuda.is_available
+            torch.cuda.is_available = lambda: True
+
+            from optimum.gptq import GPTQQuantizer
+
+            orig_post_init_model = GPTQQuantizer.post_init_model
+
+            def post_init_model(self, model):
+                from auto_gptq import exllama_set_max_input_length
+
+                class StoreAttr(object):
+                    pass
+
+                model.quantize_config = StoreAttr()
+                model.quantize_config.desc_act = self.desc_act
+                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                    model = exllama_set_max_input_length(model, self.max_input_length)
+                return model
+
+            GPTQQuantizer.post_init_model = post_init_model
+
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -238,10 +266,14 @@ def _from_transformers(
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
-            model_kwargs=kwargs,
             int8=load_in_8bit,
         )
 
+        # Unpatch modules after GPTQ export
+        if do_gptq_patching:
+            torch.cuda.is_available = orig_cuda_check
+            GPTQQuantizer.post_init_model = orig_post_init_model
+
         config.is_decoder = True
         config.is_encoder_decoder = False
         config.save_pretrained(save_dir_path)
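
Review note: the added hunks save the original `torch.cuda.is_available` and `GPTQQuantizer.post_init_model`, install replacements so that auto-gptq's hard CUDA check passes while `main_export` runs on a CPU-only host, and restore the originals afterwards. Below is a minimal standalone sketch of the same reversible monkey-patching idea, written as a context manager; the helper name `fake_cuda_available` is hypothetical and not part of this patch, and only `torch` is assumed installed:

```python
import contextlib

import torch


@contextlib.contextmanager
def fake_cuda_available():
    # Save the real check, report CUDA as available while the block runs,
    # and restore the real check on exit, even if the body raises.
    orig_cuda_check = torch.cuda.is_available
    torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        torch.cuda.is_available = orig_cuda_check


with fake_cuda_available():
    print(torch.cuda.is_available())  # True while patched
print(torch.cuda.is_available())  # the real answer after restore
```

One design difference worth noting: the `finally` in the sketch restores the patch even when the wrapped call raises, whereas the diff's linear save/patch/restore sequence would leave `torch.cuda.is_available` patched if `main_export` threw.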