70
70
71
71
72
72
class QuantizationTest (INCTestMixin ):
73
- SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
73
+ SUPPORTED_ARCHITECTURES_STATIC = (
74
+ ("text-generation" , "gpt_neo" , 17 ),
74
75
("text-classification" , "bert" , 21 ),
75
- # ("text-generation", "bloom", 21),
76
+ ("text-generation" , "bloom" , 21 ),
76
77
)
77
78
78
- SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS + (
79
+ SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_STATIC + (
79
80
("fill-mask" , "bert" , 22 ),
80
81
("token-classification" , "albert" , 26 ),
81
82
)
@@ -88,12 +89,14 @@ class QuantizationTest(INCTestMixin):
88
89
@parameterized .expand (SUPPORTED_ARCHITECTURES_DYNAMIC )
89
90
def test_dynamic_quantization (self , task , model_arch , expected_quantized_matmuls ):
90
91
model_name = MODEL_NAMES [model_arch ]
91
- quantization_config = PostTrainingQuantConfig (approach = "dynamic" )
92
92
model_class = ORT_SUPPORTED_TASKS [task ]["class" ][0 ]
93
93
tokenizer = AutoTokenizer .from_pretrained (model_name )
94
- save_onnx_model = False
94
+
95
95
quantized_model = None
96
+ save_onnx_model = False
96
97
model_kwargs = {"use_cache" : False , "use_io_binding" : False } if task == "text-generation" else {}
98
+ quantization_config = PostTrainingQuantConfig (approach = "dynamic" )
99
+
97
100
with tempfile .TemporaryDirectory () as tmp_dir :
98
101
for backend in ["torch" , "ort" ]:
99
102
if backend == "torch" :
@@ -104,8 +107,8 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls
104
107
quantizer = INCQuantizer .from_pretrained (model , task = task )
105
108
quantizer .quantize (
106
109
quantization_config = quantization_config ,
107
- save_directory = tmp_dir ,
108
110
save_onnx_model = save_onnx_model ,
111
+ save_directory = tmp_dir ,
109
112
)
110
113
if backend == "torch" :
111
114
quantized_model = quantizer ._quantized_model
@@ -121,7 +124,7 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls
121
124
load_inc_model = True ,
122
125
)
123
126
124
- @parameterized .expand (SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS )
127
+ @parameterized .expand (SUPPORTED_ARCHITECTURES_STATIC )
125
128
def test_static_quantization (self , task , model_arch , expected_quantized_matmuls ):
126
129
num_samples = 10
127
130
model_name = MODEL_NAMES [model_arch ]
@@ -130,28 +133,26 @@ def test_static_quantization(self, task, model_arch, expected_quantized_matmuls)
130
133
if tokenizer .pad_token is None :
131
134
tokenizer .pad_token = tokenizer .eos_token
132
135
133
- save_onnx_model = False
134
- op_type_dict = (
135
- {"Embedding" : {"weight" : {"dtype" : ["fp32" ]}, "activation" : {"dtype" : ["fp32" ]}}}
136
- if save_onnx_model
137
- else None
138
- )
139
- quantization_config = PostTrainingQuantConfig (approach = "static" , op_type_dict = op_type_dict )
140
136
quantized_model = None
137
+ save_onnx_model = False
138
+ quantization_config = PostTrainingQuantConfig (approach = "static" )
139
+ model_kwargs = {"use_cache" : False , "use_io_binding" : False } if task == "text-generation" else {}
141
140
142
141
with tempfile .TemporaryDirectory () as tmp_dir :
143
142
for backend in ["torch" , "ort" ]:
144
143
if backend == "torch" :
145
144
model = model_class .auto_model_class .from_pretrained (model_name )
146
145
else :
147
- model = model_class .from_pretrained (model_name , export = True )
146
+ model = model_class .from_pretrained (model_name , export = True , ** model_kwargs )
147
+
148
148
quantizer = INCQuantizer .from_pretrained (model , task = task )
149
149
calibration_dataset = _generate_dataset (quantizer , tokenizer , num_samples = num_samples )
150
+
150
151
quantizer .quantize (
151
152
quantization_config = quantization_config ,
152
153
calibration_dataset = calibration_dataset ,
153
- save_directory = tmp_dir ,
154
154
save_onnx_model = save_onnx_model ,
155
+ save_directory = tmp_dir ,
155
156
)
156
157
if backend == "torch" :
157
158
quantized_model = quantizer ._quantized_model
0 commit comments