70
70
71
71
72
72
class QuantizationTest (INCTestMixin ):
73
- SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
73
+ SUPPORTED_ARCHITECTURES_STATIC = (
74
+ ("text-generation" , "gpt_neo" , 17 ),
74
75
("text-classification" , "bert" , 21 ),
75
76
("text-generation" , "bloom" , 21 ),
76
77
)
77
78
78
- SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS + (
79
+ SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_STATIC + (
79
80
("fill-mask" , "bert" , 22 ),
80
81
("token-classification" , "albert" , 26 ),
81
82
)
@@ -123,7 +124,7 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls
123
124
load_inc_model = True ,
124
125
)
125
126
126
- @parameterized .expand (SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS )
127
+ @parameterized .expand (SUPPORTED_ARCHITECTURES_STATIC )
127
128
def test_static_quantization (self , task , model_arch , expected_quantized_matmuls ):
128
129
num_samples = 10
129
130
model_name = MODEL_NAMES [model_arch ]
@@ -134,22 +135,19 @@ def test_static_quantization(self, task, model_arch, expected_quantized_matmuls)
134
135
135
136
quantized_model = None
136
137
save_onnx_model = False
137
- op_type_dict = (
138
- {"Embedding" : {"weight" : {"dtype" : ["fp32" ]}, "activation" : {"dtype" : ["fp32" ]}}}
139
- if save_onnx_model
140
- else None
141
- )
138
+ quantization_config = PostTrainingQuantConfig (approach = "static" )
142
139
model_kwargs = {"use_cache" : False , "use_io_binding" : False } if task == "text-generation" else {}
143
- quantization_config = PostTrainingQuantConfig (approach = "static" , op_type_dict = op_type_dict )
144
140
145
141
with tempfile .TemporaryDirectory () as tmp_dir :
146
142
for backend in ["torch" , "ort" ]:
147
143
if backend == "torch" :
148
144
model = model_class .auto_model_class .from_pretrained (model_name )
149
145
else :
150
146
model = model_class .from_pretrained (model_name , export = True , ** model_kwargs )
147
+
151
148
quantizer = INCQuantizer .from_pretrained (model , task = task )
152
149
calibration_dataset = _generate_dataset (quantizer , tokenizer , num_samples = num_samples )
150
+
153
151
quantizer .quantize (
154
152
quantization_config = quantization_config ,
155
153
calibration_dataset = calibration_dataset ,
0 commit comments