@@ -220,7 +220,7 @@ def __init__(
220
220
)
221
221
self .exllama_version = self .exllama_config ["version" ]
222
222
223
- def select_quant_linear (self , device_map : Union [str , dict ]):
223
+ def select_quant_linear (self , device_map : Union [str , dict ], pack : bool = False ):
224
224
if is_gptqmodel_available ():
225
225
self .quant_linear = hf_select_quant_linear (
226
226
bits = self .bits ,
@@ -231,6 +231,7 @@ def select_quant_linear(self, device_map: Union[str, dict]):
231
231
meta = self .meta ,
232
232
device_map = device_map ,
233
233
backend = self .backend ,
234
+ pack = pack ,
234
235
)
235
236
else :
236
237
self .quant_linear = hf_select_quant_linear (
@@ -301,7 +302,7 @@ def convert_model(self, model: nn.Module, **kwargs):
301
302
)
302
303
del layers_to_be_replaced [name ]
303
304
304
- self .select_quant_linear (device_map = kwargs .get ("device_map" , None ))
305
+ self .select_quant_linear (device_map = kwargs .get ("device_map" , None ), pack = False )
305
306
306
307
self ._replace_by_quant_layers (model , layers_to_be_replaced )
307
308
@@ -761,7 +762,7 @@ def pack_model(
761
762
layers = get_layers (model )
762
763
layers = {n : layers [n ] for n in quantizers }
763
764
764
- self .select_quant_linear (device_map = model .hf_device_map )
765
+ self .select_quant_linear (device_map = model .hf_device_map , pack = True )
765
766
766
767
self ._replace_by_quant_layers (model , quantizers )
767
768
qlayers = get_layers (model , [self .quant_linear ])
0 commit comments