
Commit 2b220bb

add xfail for torchvision

Signed-off-by: Xin He <xinhe3@habana.ai>
1 parent 572f7c0, commit 2b220bb
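
This commit moves the torchvision-dependent ResNet tests into their own class and marks that class with pytest.mark.xfail, so the known torchvision::nms failure tracked as [SW-219514] is reported as an expected failure instead of breaking the whole suite. A minimal sketch of how a class-level xfail behaves (the marker applies to every test method in the class; the class name and failing body below are illustrative, not from this commit):

import pytest

@pytest.mark.xfail(reason="known issue: operator torchvision::nms does not exist")
class TestKnownBroken:
    def test_case(self):
        # A failing test is reported as XFAIL rather than FAIL; if it
        # unexpectedly passes, pytest reports XPASS (non-strict default).
        raise RuntimeError("operator torchvision::nms does not exist")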

File tree: 1 file changed (+33, -22)


test/3x/torch/quantization/fp8_quant/test_fp8_static_quant.py (+33, -22)
@@ -3,7 +3,6 @@
 
 import pytest
 import torch
-import torchvision
 import transformers
 
 from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedConv2d, PatchedLinear
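
Removing the module-level import torchvision is the key part of this hunk: an import that fails during test collection errors out every test in the file, including the NLP tests that never touch torchvision. The commit instead defers the import into setup_class of the new CV-only class (see the third hunk below). A related pattern, not what this commit uses, is pytest.importorskip, which turns a failed import into a skip for just the dependent test:

import pytest
import torch

def test_resnet18_forward():
    # Skips this test (instead of erroring) if torchvision cannot be imported.
    torchvision = pytest.importorskip("torchvision")
    model = torchvision.models.resnet18().eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 1000)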
@@ -36,14 +35,12 @@ def calib_func(model):
 
 
 @pytest.mark.skipif(not is_hpex_available(), reason="HPU environment is required!")
-class TestFP8StaticQuant:
+class TestFP8StaticQuantNLP:
     def setup_class(self):
         change_to_cur_file_dir()
         config = transformers.AutoConfig.from_pretrained("./model_configs/tiny_gptj.json")
         self.tiny_gptj = transformers.AutoModelForCausalLM.from_config(config)
         self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long).to("hpu")
-        self.resnet18 = torchvision.models.resnet18()
-        self.cv_dummy_inputs = torch.randn([1, 3, 224, 224]).to("hpu")
 
     def teardown_class(self):
         shutil.rmtree("test_ouputs", ignore_errors=True)
@@ -72,6 +69,38 @@ def test_one_step_quant_nlp(self):
         ), "k_proj input dtype is not torch.float8_e4m3fn."
         assert (fp32_out != fp8_out).any(), "FP32 output should be different with FP8 output"
 
+    @torch.no_grad()
+    def test_two_step_quant_nlp(self):
+        # step 1: measurement
+        model = copy.deepcopy(self.tiny_gptj)
+        config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
+        model = prepare(model, config)
+        calib_func(model)
+        finalize_calibration(model)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not observed."
+        # step 2: quantize based on measurement
+        model = copy.deepcopy(self.tiny_gptj)
+        config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json")
+        model = convert(model, config)
+        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not quantized."
+        assert (
+            model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn
+        ), "k_proj input dtype is not torch.float8_e4m3fn."
+
+
+@pytest.mark.xfail(reason="[SW-219514] RuntimeError: operator torchvision::nms does not exist")
+@pytest.mark.skipif(not is_hpex_available(), reason="HPU environment is required!")
+class TestFP8StaticQuantCV:
+    def setup_class(self):
+        change_to_cur_file_dir()
+        import torchvision
+        self.resnet18 = torchvision.models.resnet18()
+        self.cv_dummy_inputs = torch.randn([1, 3, 224, 224]).to("hpu")
+
+    def teardown_class(self):
+        shutil.rmtree("test_ouputs", ignore_errors=True)
+        shutil.rmtree("saved_results", ignore_errors=True)
+
     @torch.no_grad()
     def test_one_step_quant_cv(self):
         model = copy.deepcopy(self.resnet18)
@@ -94,24 +123,6 @@ def test_one_step_quant_cv(self):
         ), "model is not quantized to torch.float8_e4m3fn."
         assert (fp32_out != fp8_out).any(), "FP32 output should be different with FP8 output"
 
-    @torch.no_grad()
-    def test_two_step_quant_nlp(self):
-        # step 1: measurement
-        model = copy.deepcopy(self.tiny_gptj)
-        config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
-        model = prepare(model, config)
-        calib_func(model)
-        finalize_calibration(model)
-        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not observed."
-        # step 2: quantize based on measurement
-        model = copy.deepcopy(self.tiny_gptj)
-        config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json")
-        model = convert(model, config)
-        assert isinstance(model.transformer.h[0].attn.k_proj, PatchedLinear), "k_proj is not quantized."
-        assert (
-            model.transformer.h[0].attn.k_proj.quant_input.lp_dtype == torch.float8_e4m3fn
-        ), "k_proj input dtype is not torch.float8_e4m3fn."
-
     @torch.no_grad()
     def test_two_step_quant_cv(self):
         # step 1: measurement
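
Both the relocated test_two_step_quant_nlp and its CV counterpart exercise the same measure-then-quantize flow. As a standalone sketch of that flow, assuming the FP8Config/prepare/convert/finalize_calibration imports from neural_compressor.torch.quantization used elsewhere in this file, and with a hypothetical toy model and calibration function standing in for the real ones:

import copy

import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare

toy_model = torch.nn.Sequential(torch.nn.Linear(8, 8)).to("hpu")  # hypothetical stand-in model

def toy_calib(model):
    # hypothetical calibration pass feeding representative inputs
    model(torch.randn(2, 8).to("hpu"))

# Step 1: measurement -- run calibration data to collect tensor statistics.
model = copy.deepcopy(toy_model)
config = FP8Config.from_json_file("test_fp8_jsons/test_measure.json")
model = prepare(model, config)
toy_calib(model)
finalize_calibration(model)  # persists the collected measurements

# Step 2: quantization -- on a fresh copy, convert using the saved measurements.
model = copy.deepcopy(toy_model)
config = FP8Config.from_json_file("test_fp8_jsons/test_hw_quant.json")
model = convert(model, config)  # Linear layers become PatchedLinear with FP8 (e4m3) inputs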
