Skip to content

Commit df72e9f

Browse files
added bert static test
1 parent b017856 commit df72e9f

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tests/neural_compressor/test_optimization.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
class QuantizationTest(INCTestMixin):
7373
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
7474
("text-classification", "bert", 21),
75-
# ("text-generation", "bloom", 21),
75+
("text-generation", "bloom", 21),
7676
)
7777

7878
SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS + (
@@ -88,12 +88,14 @@ class QuantizationTest(INCTestMixin):
8888
@parameterized.expand(SUPPORTED_ARCHITECTURES_DYNAMIC)
8989
def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls):
9090
model_name = MODEL_NAMES[model_arch]
91-
quantization_config = PostTrainingQuantConfig(approach="dynamic")
9291
model_class = ORT_SUPPORTED_TASKS[task]["class"][0]
9392
tokenizer = AutoTokenizer.from_pretrained(model_name)
94-
save_onnx_model = False
93+
9594
quantized_model = None
95+
save_onnx_model = False
9696
model_kwargs = {"use_cache": False, "use_io_binding": False} if task == "text-generation" else {}
97+
quantization_config = PostTrainingQuantConfig(approach="dynamic")
98+
9799
with tempfile.TemporaryDirectory() as tmp_dir:
98100
for backend in ["torch", "ort"]:
99101
if backend == "torch":
@@ -104,8 +106,8 @@ def test_dynamic_quantization(self, task, model_arch, expected_quantized_matmuls
104106
quantizer = INCQuantizer.from_pretrained(model, task=task)
105107
quantizer.quantize(
106108
quantization_config=quantization_config,
107-
save_directory=tmp_dir,
108109
save_onnx_model=save_onnx_model,
110+
save_directory=tmp_dir,
109111
)
110112
if backend == "torch":
111113
quantized_model = quantizer._quantized_model
@@ -130,28 +132,29 @@ def test_static_quantization(self, task, model_arch, expected_quantized_matmuls)
130132
if tokenizer.pad_token is None:
131133
tokenizer.pad_token = tokenizer.eos_token
132134

135+
quantized_model = None
133136
save_onnx_model = False
134137
op_type_dict = (
135138
{"Embedding": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}}
136139
if save_onnx_model
137140
else None
138141
)
142+
model_kwargs = {"use_cache": False, "use_io_binding": False} if task == "text-generation" else {}
139143
quantization_config = PostTrainingQuantConfig(approach="static", op_type_dict=op_type_dict)
140-
quantized_model = None
141144

142145
with tempfile.TemporaryDirectory() as tmp_dir:
143146
for backend in ["torch", "ort"]:
144147
if backend == "torch":
145148
model = model_class.auto_model_class.from_pretrained(model_name)
146149
else:
147-
model = model_class.from_pretrained(model_name, export=True)
150+
model = model_class.from_pretrained(model_name, export=True, **model_kwargs)
148151
quantizer = INCQuantizer.from_pretrained(model, task=task)
149152
calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=num_samples)
150153
quantizer.quantize(
151154
quantization_config=quantization_config,
152155
calibration_dataset=calibration_dataset,
153-
save_directory=tmp_dir,
154156
save_onnx_model=save_onnx_model,
157+
save_directory=tmp_dir,
155158
)
156159
if backend == "torch":
157160
quantized_model = quantizer._quantized_model

0 commit comments

Comments (0)