Commit f9b30c1

revert tests
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 7312b7a commit f9b30c1

2 files changed: +51 -97 lines changed

optimum/utils/testing_utils.py (+3 -6)
@@ -30,7 +30,6 @@
     is_auto_gptq_available,
     is_datasets_available,
     is_diffusers_available,
-    is_gptqmodel_available,
     is_sentence_transformers_available,
     is_timm_available,
 )
@@ -61,13 +60,11 @@ def require_accelerate(test_case):
     return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


-def require_gptq(test_case):
+def require_auto_gptq(test_case):
     """
-    Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed.
+    Decorator marking a test that requires auto-gptq. These tests are skipped when auto-gptq isn't installed.
     """
-    return unittest.skipUnless(
-        is_auto_gptq_available() or is_gptqmodel_available(), "test requires gptqmodel or auto-gptq"
-    )(test_case)
+    return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


 def require_torch_gpu(test_case):
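For context, the restored decorator is used the same way as the other helpers in testing_utils.py. A minimal usage sketch follows; the test class name and body are illustrative only and not part of this commit:

```python
import unittest

from optimum.utils.testing_utils import require_auto_gptq, require_torch_gpu


@require_auto_gptq  # skipped with "test requires auto-gptq" when auto-gptq is missing
@require_torch_gpu
class ExampleGPTQTest(unittest.TestCase):
    def test_can_import_auto_gptq(self):
        # Only runs when auto-gptq is installed and a CUDA device is available.
        from auto_gptq import AutoGPTQForCausalLM  # noqa: F401
```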

tests/gptq/test_quantization.py (+48 -91)
@@ -26,42 +26,38 @@
 from optimum.gptq.eval import evaluate_perplexity
 from optimum.gptq.utils import get_block_name_with_pattern, get_preceding_modules, get_seqlen
 from optimum.utils import recurse_getattr
-from optimum.utils.import_utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available
-from optimum.utils.testing_utils import require_gptq, require_torch_gpu
+from optimum.utils.import_utils import is_accelerate_available, is_auto_gptq_available
+from optimum.utils.testing_utils import require_auto_gptq, require_torch_gpu


 if is_auto_gptq_available():
     from auto_gptq import AutoGPTQForCausalLM
-    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear
-
-if is_gptqmodel_available():
-    from gptqmodel import GPTQModel
-    from gptqmodel.utils.importer import hf_select_quant_linear
+    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

 if is_accelerate_available():
     from accelerate import init_empty_weights


 @slow
-@require_gptq
+@require_auto_gptq
+@require_torch_gpu
 class GPTQTest(unittest.TestCase):
-    model_name = "Felladrin/Llama-68M-Chat-v1"
+    model_name = "bigscience/bloom-560m"

     expected_fp16_perplexity = 30
     expected_quantized_perplexity = 34

-    expected_compression_ratio = 1.2577
+    expected_compression_ratio = 1.66

     bits = 4
     group_size = 128
     desc_act = False
-    sym = True
     disable_exllama = True
     exllama_config = None
     cache_block_outputs = True
     modules_in_block_to_quantize = None
-    device_map_for_quantization = "cpu"
-    device_for_inference = "cpu"
+    device_map_for_quantization = "cuda"
+    device_for_inference = 0
     dataset = [
         "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
     ]
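The module paths asserted later in this diff follow from the model swap above: the pre-revert Llama checkpoint exposes its linear layers under model.layers[*], while bigscience/bloom-560m exposes them under transformer.h[*]. A quick, hedged way to confirm the layout (illustrative only, not part of the commit):

```python
# Illustrative only: inspect where the quantizable linear layers live in the
# BLOOM checkpoint restored by this revert.
from transformers import AutoModelForCausalLM

bloom = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
print(type(bloom.transformer.h[0].mlp.dense_4h_to_h))               # BLOOM MLP projection
print(type(bloom.transformer.h[0].self_attention.query_key_value))  # fused QKV projection

# The Llama-style checkpoint used before the revert keeps the equivalent layers at
# model.model.layers[0].mlp.gate_proj and model.model.layers[0].self_attn.q_proj.
```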
@@ -74,7 +70,6 @@ def setUpClass(cls):
         """

         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
-        cls.config = AutoConfig.from_pretrained(cls.model_name)

         cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
             cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map_for_quantization
@@ -89,7 +84,6 @@ def setUpClass(cls):
             dataset=cls.dataset,
             group_size=cls.group_size,
             desc_act=cls.desc_act,
-            sym=cls.sym,
             disable_exllama=cls.disable_exllama,
             exllama_config=cls.exllama_config,
             cache_block_outputs=cls.cache_block_outputs,
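The two setUpClass hunks above revert the quantization setup to plain auto-gptq settings (no sym, no gptqmodel checkpoint_format handling). A condensed sketch of that setup, assuming optimum's GPTQQuantizer API as used in this test file and the reverted class attributes above; variable names here are illustrative:

```python
# Hedged sketch of what setUpClass does after the revert: quantize bloom-560m
# on CUDA with 4-bit GPTQ, group_size=128, desc_act=False.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer

model_name = "bigscience/bloom-560m"
dataset = [
    "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
]

tokenizer = AutoTokenizer.from_pretrained(model_name)
model_fp16 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cuda")

quantizer = GPTQQuantizer(
    bits=4,
    dataset=dataset,
    group_size=128,
    desc_act=False,
    disable_exllama=True,
    exllama_config=None,
    cache_block_outputs=True,
)
quantized_model = quantizer.quantize_model(model_fp16, tokenizer)
```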
@@ -110,51 +104,38 @@ def test_memory_footprint(self):
         self.assertAlmostEqual(self.fp16_mem / self.quantized_mem, self.expected_compression_ratio, places=2)

     def test_perplexity(self):
-        pass
+        """
+        A simple test to check if the model conversion has been done correctly by checking on the
+        the perplexity of the converted models
+        """
+
+        self.assertEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
+        self.assertEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)

     def test_quantized_layers_class(self):
         """
         A simple test to check if the model conversion has been done correctly by checking on the
         the class type of the linear layers of the converted models
         """

-        if is_gptqmodel_available():
-            if hasattr(self.config, "quantization_config"):
-                checkpoint_format = self.config.quantization_config.get("checkpoint_format")
-                meta = self.config.quantization_config.get("meta")
-            else:
-                checkpoint_format = "gptq"
-                meta = None
-            QuantLinear = hf_select_quant_linear(
-                bits=self.bits,
-                group_size=self.group_size,
-                desc_act=self.desc_act,
-                sym=self.sym,
-                device_map=self.device_map_for_quantization,
-                checkpoint_format=checkpoint_format,
-                meta=meta,
-            )
-        else:
-            QuantLinear = hf_select_quant_linear(
-                use_triton=False,
-                desc_act=self.desc_act,
-                group_size=self.group_size,
-                bits=self.bits,
-                disable_exllama=self.disable_exllama or self.exllama_config["version"] != 1,
-                disable_exllamav2=self.disable_exllama or self.exllama_config["version"] != 2,
-            )
-        self.assertEqual(self.quantized_model.model.layers[0].mlp.gate_proj.__class__, QuantLinear)
+        QuantLinear = dynamically_import_QuantLinear(
+            use_triton=False,
+            use_qigen=False,
+            desc_act=self.desc_act,
+            group_size=self.group_size,
+            bits=self.bits,
+            disable_exllama=self.disable_exllama or self.exllama_config["version"] != 1,
+            disable_exllamav2=self.disable_exllama or self.exllama_config["version"] != 2,
+        )
+        self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)

     def check_quantized_layers_type(self, model, value):
-        self.assertEqual(model.model.layers[0].mlp.gate_proj.QUANT_TYPE, value)
+        self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.QUANT_TYPE == value)

     def test_serialization(self):
         """
         Test the serialization of the model and the loading of the quantized weights
         """
-        # AutoGPTQ does not support CPU
-        if self.device_map_for_quantization == "cpu" and not is_gptqmodel_available():
-            return

         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantizer.save(self.quantized_model, tmpdirname)
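The restored test_quantized_layers_class relies on auto-gptq's kernel selection: dynamically_import_QuantLinear returns the QuantLinear implementation matching the quantization settings, and each implementation carries a QUANT_TYPE tag such as "cuda-old", "exllama", or "exllamav2". A hedged standalone sketch of that selection, mirroring the restored arguments:

```python
# Illustrative kernel selection; the exact class and QUANT_TYPE value depend on
# the installed auto-gptq version and which CUDA kernels are available.
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

disable_exllama = False
exllama_config = {"version": 1}

QuantLinear = dynamically_import_QuantLinear(
    use_triton=False,
    use_qigen=False,
    desc_act=False,
    group_size=128,
    bits=4,
    disable_exllama=disable_exllama or exllama_config["version"] != 1,
    disable_exllamav2=disable_exllama or exllama_config["version"] != 2,
)
print(QuantLinear.QUANT_TYPE)  # expected to report "exllama" with these settings
```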
@@ -171,50 +152,33 @@ def test_serialization(self):
                 disable_exllama=self.disable_exllama,
                 exllama_config=self.exllama_config,
             )
-            if is_auto_gptq_available() and not is_gptqmodel_available():
-                quant_type = "cuda-old" if self.disable_exllama else "exllama"
+            if self.disable_exllama:
+                self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
             else:
-                quant_type = "ipex" if self.device_map_for_quantization == "cpu" else "exllama"
-
-            self.check_quantized_layers_type(quantized_model_from_saved, quant_type)
+                self.check_quantized_layers_type(quantized_model_from_saved, "exllama")

             # transformers and auto-gptq compatibility
             # quantized models are more compatible with device map than
             # device context managers (they're never used in transformers testing suite)
             _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
-            if is_gptqmodel_available():
-                _ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
-            else:
-                _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
+            _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


-@require_torch_gpu
-class GPTQTestCUDA(GPTQTest):
-    device_map_for_quantization = "cuda"
-    device_for_inference = 0
-    expected_compression_ratio = 1.2577
-    expected_fp16_perplexity = 38
-    expected_quantized_perplexity = 45
+class GPTQTestCPUInit(GPTQTest):
+    device_map_for_quantization = "cpu"

     def test_perplexity(self):
-        """
-        A simple test to check if the model conversion has been done correctly by checking on the
-        the perplexity of the converted models
-        """
-
-        self.assertLessEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
-        self.assertLessEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)
+        pass


-class GPTQTestExllama(GPTQTestCUDA):
+class GPTQTestExllama(GPTQTest):
     disable_exllama = False
     exllama_config = {"version": 1}


-class GPTQTestActOrder(GPTQTestCUDA):
+class GPTQTestActOrder(GPTQTest):
     disable_exllama = True
     desc_act = True
-    expected_quantized_perplexity = 46

     def test_serialization(self):
         # act_order don't work with qlinear_cuda kernel
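The reverted test_serialization keeps the original round trip: optimum saves the quantized weights, transformers reloads them via the GPTQ quantization_config written to config.json, and auto-gptq reloads the same folder with from_quantized. A self-contained sketch of that round trip; the helper name and default device argument are illustrative, not from the test file:

```python
import tempfile

from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoModelForCausalLM


def serialization_roundtrip(quantizer, quantized_model, device_for_inference=0):
    """Save a GPTQ-quantized model, then reload it through both entry points."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        # optimum's GPTQQuantizer writes the quantized state dict and config
        quantizer.save(quantized_model, tmpdirname)
        quantized_model.config.save_pretrained(tmpdirname)

        # transformers path: loading is driven by the stored quantization_config
        _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": device_for_inference})
        # auto-gptq path: the same checkpoint through its dedicated loader
        _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": device_for_inference})
```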
@@ -245,10 +209,7 @@ def test_exllama_serialization(self):
             # quantized models are more compatible with device map than
             # device context managers (they're never used in transformers testing suite)
             _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
-            if is_gptqmodel_available():
-                _ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
-            else:
-                _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
+            _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})

     def test_exllama_max_input_length(self):
         """
@@ -285,7 +246,7 @@ def test_exllama_max_input_length(self):
         quantized_model_from_saved.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)


-class GPTQTestExllamav2(GPTQTestCUDA):
+class GPTQTestExllamav2(GPTQTest):
     desc_act = False
     disable_exllama = True
     exllama_config = {"version": 2}
@@ -298,6 +259,7 @@ def test_exllama_serialization(self):
         """
         Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
         """
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantizer.save(self.quantized_model, tmpdirname)
             self.quantized_model.config.save_pretrained(tmpdirname)
@@ -311,36 +273,31 @@ def test_exllama_serialization(self):
                 save_folder=tmpdirname,
                 device_map={"": self.device_for_inference},
             )
-            self.check_quantized_layers_type(
-                quantized_model_from_saved, "exllama" if is_gptqmodel_available() else "exllamav2"
-            )
+            self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2")

             # transformers and auto-gptq compatibility
             # quantized models are more compatible with device map than
             # device context managers (they're never used in transformers testing suite)
             _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
-            if is_gptqmodel_available():
-                _ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
-            else:
-                _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
+            _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


-class GPTQTestNoBlockCaching(GPTQTestCUDA):
+class GPTQTestNoBlockCaching(GPTQTest):
     cache_block_outputs = False


-class GPTQTestModuleQuant(GPTQTestCUDA):
+class GPTQTestModuleQuant(GPTQTest):
     # all layers are quantized apart from self_attention.dense
     modules_in_block_to_quantize = [
-        ["self_attn.q_proj"],
-        ["mlp.gate_proj"],
+        ["self_attention.query_key_value"],
+        ["mlp.dense_h_to_4h"],
+        ["mlp.dense_4h_to_h"],
     ]
-    expected_compression_ratio = 1.068
-    expected_quantized_perplexity = 39
+    expected_compression_ratio = 1.577

     def test_not_converted_layers(self):
         # self_attention.dense should not be converted
-        self.assertEqual(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__, "Linear")
+        self.assertTrue(self.quantized_model.transformer.h[0].self_attention.dense.__class__.__name__ == "Linear")


 class GPTQUtilsTest(unittest.TestCase):
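For the GPTQTestModuleQuant revert above, only the listed BLOOM submodules are quantized inside each block, which is why self_attention.dense is expected to stay a plain nn.Linear and the expected compression ratio (1.577) is lower than the full-quantization value (1.66). A hedged sketch of that configuration, assuming optimum's modules_in_block_to_quantize parameter and an illustrative variable name:

```python
# Illustrative partial-block quantization with the reverted module lists; the
# quantized model should keep self_attention.dense as a regular torch.nn.Linear.
from optimum.gptq import GPTQQuantizer

partial_quantizer = GPTQQuantizer(
    bits=4,
    dataset=[
        "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
    ],
    group_size=128,
    desc_act=False,
    modules_in_block_to_quantize=[
        ["self_attention.query_key_value"],
        ["mlp.dense_h_to_4h"],
        ["mlp.dense_4h_to_h"],
    ],
)
# After partial_quantizer.quantize_model(model, tokenizer):
#   model.transformer.h[0].self_attention.dense.__class__.__name__ == "Linear"
```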
