@@ -160,6 +160,7 @@ def main_export(
         )
         convert_tokenizer = False
 
+    do_gptq_patching = False
     custom_architecture = False
     loading_kwargs = {}
     if library_name == "transformers":
@@ -173,6 +174,8 @@ def main_export(
             force_download=force_download,
             trust_remote_code=trust_remote_code,
         )
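+        # GPTQ checkpoints are detected via the "quant_method" field of the config's quantization_config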
+        quantization_config = getattr(config, "quantization_config", None)
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
         model_type = config.model_type.replace("_", "-")
 
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
@@ -193,6 +196,32 @@ def main_export(
         if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
             loading_kwargs["attn_implementation"] = "eager"
 
+    # Patch modules to export GPTQ models without a GPU
+    if do_gptq_patching:
+        import torch
+
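+        # GPTQ checkpoints typically default to fp16, which is not fully supported on CPU, so force fp32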
+        torch.set_default_dtype(torch.float32)
+        orig_cuda_check = torch.cuda.is_available
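+        # Report CUDA as available so the quantized model can be loaded for export on CPU;
+        # the original check is restored once the export finishes (see the unpatching below)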
+        torch.cuda.is_available = lambda: True
+
+        from optimum.gptq import GPTQQuantizer
+
+        orig_post_init_model = GPTQQuantizer.post_init_model
+
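+        # Swap in a post_init_model that only records desc_act on the model and, when
+        # required, extends the exllama max input length, bypassing device-dependent init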
+        def post_init_model(self, model):
+            from auto_gptq import exllama_set_max_input_length
+
+            class StoreAttr(object):
+                pass
+
+            model.quantize_config = StoreAttr()
+            model.quantize_config.desc_act = self.desc_act
+            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                model = exllama_set_max_input_length(model, self.max_input_length)
+            return model
+
+        GPTQQuantizer.post_init_model = post_init_model
+
     model = TasksManager.get_model_from_task(
         task,
         model_name_or_path,
@@ -295,3 +324,8 @@ def main_export(
             tokenizer_2 = getattr(model, "tokenizer_2", None)
             if tokenizer_2 is not None:
                 export_tokenizer(tokenizer_2, output, suffix="_2")
+
+    # Unpatch modules after GPTQ export
+    if do_gptq_patching:
+        torch.cuda.is_available = orig_cuda_check
+        GPTQQuantizer.post_init_model = orig_post_init_model