Skip to content

Commit 261c2b2

Browse files
linoybuxinhe3
authored and committed
Fix bug in mixtral unitscale (#141)
* [SW-218197] fix bug in Mixtral unitscale * [SW-218197] fix bug in Mixtral unitscale
1 parent 7be8dd2 commit 261c2b2

File tree

1 file changed

+8
-4
lines changed
  • neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods

1 file changed

+8
-4
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def init_scales_from_module_config(self, module):
5252
for idx, output in enumerate(module.outputs):
5353
if self.output_scales_creators[idx].scale is None:
5454
self.output_scales_creators[idx].scale = output
55-
55+
5656
def calc_input_scales(self, num_of_inputs):
5757
input_scales = []
5858
for i in range(num_of_inputs):
@@ -96,7 +96,7 @@ def __init__(self, config, mod, measurement, params, module_type):
9696
def get_scales_module_config(self):
9797
input_scales = self.calc_input_scales(num_of_inputs=1)
9898
output_measurement = self.measurement.outputs[0] if self.measurement is not None else []
99-
rescaled_weight = self.mod.weight
99+
rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None
100100
if self.weight_ich_scale_calc is not None:
101101
weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST)
102102
rescaled_weight = scale_fcn(self.mod.weight, weight_scales_in_ch.reshape([1, -1]))
@@ -265,7 +265,12 @@ class DynamicMoeOpQuantizer(BaseOpQuantizer):
265265

266266
def __init__(self, config, mod, measurement, params, module_type):
267267
super().__init__(config, mod, measurement, params, module_type)
268-
self.inputs_scales_creators = [self.scales_method_factory.get_scale_method(QuantTensorName.INPUT) for i in range(len(measurement.inputs) + mod.num_experts)]
268+
num_of_inputs = len(self.measurement.inputs) if self.measurement is not None else 1
269+
num_of_experts = self.mod.num_experts if self.mod.num_experts is not None else 8
270+
self.inputs_scales_creators = [
271+
self.scales_method_factory.get_scale_method(QuantTensorName.INPUT)
272+
for i in range(num_of_inputs + num_of_experts)
273+
]
269274
self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT))
270275

271276
def get_scales_module_config(self):
@@ -304,4 +309,3 @@ def scales_module_config_to_q_and_dq(self, module):
304309

305310
def get_op_quantizer(module_type, config, mod, measurement, params):
306311
return ops_quantizer_map[module_type](config, mod, measurement, params, module_type)
307-

0 commit comments

Comments (0)