@@ -201,6 +201,8 @@ def apply(
201
201
config = wp .compression_config
202
202
203
203
s , X = process_stats (statistics [k ], self ._subset_size )
204
+ s = s .astype (TensorDataType .float32 )
205
+ X = X .astype (TensorDataType .float32 )
204
206
205
207
top_k = max (int (s .shape [0 ] * self ._percent_to_apply ), 1 )
206
208
topk_idxs = fns .argsort (- s )[:top_k ]
@@ -218,6 +220,8 @@ def apply(
218
220
weight = self ._backend_entity .get_weight (
219
221
wp .node_with_weight , weight_port_id , model , graph
220
222
) # get_const_value(wp.weight_node)
223
+ weight_dtype = weight .dtype
224
+ weight = weight .astype (TensorDataType .float32 )
221
225
assert isinstance (wp .reduction_axes , tuple ) and len (wp .reduction_axes ) == 1
222
226
reduction_axis = wp .reduction_axes [0 ]
223
227
@@ -279,19 +283,19 @@ def apply(
279
283
w_scale = fns .unsqueeze (w_scale , 0 )
280
284
a_scale = fns .unsqueeze (1.0 / a_scale , 1 )
281
285
282
- scaled_weight = weight * w_scale
286
+ scaled_weight = ( weight * w_scale ). astype ( weight_dtype )
283
287
self ._backend_entity .set_weight (wp .node_with_weight , weight_port_id , model , graph , scaled_weight )
284
288
285
289
if self ._backend_entity .is_node_with_weights (
286
290
merge_node , graph
287
291
): # for MatMul->Multiply->MatMul pattern scale merged to first MatMul
288
292
for _ , port_id in self ._backend_entity .get_weight_names_and_port_ids (merge_node , graph ):
289
293
merge_weight = self ._backend_entity .get_weight (merge_node , port_id , model , graph )
290
- merge_weight = merge_weight * a_scale
294
+ merge_weight = ( merge_weight * a_scale ). astype ( weight_dtype )
291
295
self ._backend_entity .set_weight (merge_node , port_id , model , graph , merge_weight )
292
296
a_scale = fns .transpose (a_scale )
293
297
else : # for Act->Multiply->MatMul and Act->MatMul patterns scale inserted after Act as extra node
294
- a_scale = fns .transpose (a_scale )
298
+ a_scale = fns .transpose (a_scale ). astype ( weight_dtype )
295
299
next_nodes = graph .get_next_nodes (merge_node )
296
300
source_node_output_port = graph .get_output_edges (merge_node )[0 ].output_port_id
297
301
scale_insertion_command = self ._backend_entity .scale_insertion_command (
0 commit comments