@@ -52,7 +52,7 @@ def init_scales_from_module_config(self, module):
52
52
for idx , output in enumerate (module .outputs ):
53
53
if self .output_scales_creators [idx ].scale is None :
54
54
self .output_scales_creators [idx ].scale = output
55
-
55
+
56
56
def calc_input_scales (self , num_of_inputs ):
57
57
input_scales = []
58
58
for i in range (num_of_inputs ):
@@ -96,7 +96,7 @@ def __init__(self, config, mod, measurement, params, module_type):
96
96
def get_scales_module_config (self ):
97
97
input_scales = self .calc_input_scales (num_of_inputs = 1 )
98
98
output_measurement = self .measurement .outputs [0 ] if self .measurement is not None else []
99
- rescaled_weight = self .mod .weight
99
+ rescaled_weight = self .mod .weight if hasattr ( self . mod , 'weight' ) else None
100
100
if self .weight_ich_scale_calc is not None :
101
101
weight_scales_in_ch = self .weight_ich_scale_calc .calc_scales (input_scales [0 ], QuantTensorType .CONST )
102
102
rescaled_weight = scale_fcn (self .mod .weight , weight_scales_in_ch .reshape ([1 , - 1 ]))
@@ -265,7 +265,12 @@ class DynamicMoeOpQuantizer(BaseOpQuantizer):
265
265
266
266
def __init__(self, config, mod, measurement, params, module_type):
    """Create the input and output scale creators for this quantizer.

    All arguments are forwarded unchanged to the base initializer. Afterwards
    ``self.inputs_scales_creators`` is rebuilt with one INPUT scale creator
    per measured input plus one per expert, and one OUTPUT scale creator is
    appended to ``self.output_scales_creators``.

    ``self.measurement`` and ``self.mod`` are read after the base initializer
    runs (presumably it stores the ``measurement`` and ``mod`` arguments
    there -- confirm against the base class).
    """
    super().__init__(config, mod, measurement, params, module_type)

    # Fallback counts used when the real values are unavailable.
    # NOTE(review): the default of 8 experts is carried over unchanged from
    # the previous code -- confirm it matches the intended model family.
    default_num_of_inputs = 1
    default_num_of_experts = 8

    if self.measurement is not None:
        num_of_inputs = len(self.measurement.inputs)
    else:
        num_of_inputs = default_num_of_inputs

    # getattr (rather than direct attribute access) so a module that does not
    # define ``num_experts`` at all gets the same fallback as one that sets it
    # to None, instead of raising AttributeError.
    num_of_experts = getattr(self.mod, "num_experts", None)
    if num_of_experts is None:
        num_of_experts = default_num_of_experts

    self.inputs_scales_creators = [
        self.scales_method_factory.get_scale_method(QuantTensorName.INPUT)
        for _ in range(num_of_inputs + num_of_experts)
    ]
    self.output_scales_creators.append(
        self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT)
    )
270
275
271
276
def get_scales_module_config (self ):
@@ -304,4 +309,3 @@ def scales_module_config_to_q_and_dq(self, module):
304
309
305
310
def get_op_quantizer(module_type, config, mod, measurement, params):
    """Return the op quantizer registered for ``module_type``.

    The entry stored under ``module_type`` in ``ops_quantizer_map`` is called
    with ``(config, mod, measurement, params, module_type)`` and its result is
    returned. A ``module_type`` that is not in the map propagates the lookup
    error raised by the map.
    """
    quantizer_factory = ops_quantizer_map[module_type]
    return quantizer_factory(config, mod, measurement, params, module_type)
307
-
0 commit comments