|
27 | 27 | import numpy as np
|
28 | 28 | import torch
|
29 | 29 | from datasets import load_dataset
|
| 30 | +from nncf.common.logging.track_progress import track |
30 | 31 | from nncf.quantization.advanced_parameters import OverflowFix
|
31 | 32 | from parameterized import parameterized
|
32 | 33 | import openvino.runtime as ov
|
@@ -221,7 +222,7 @@ class OVWeightCompressionTest(unittest.TestCase):
|
221 | 222 | ),
|
222 | 223 | (
|
223 | 224 | OVModelForCausalLM,
|
224 |
| - "HuggingFaceH4/tiny-random-LlamaForCausalLM", |
| 225 | + "llama_awq", |
225 | 226 | dict(
|
226 | 227 | bits=4,
|
227 | 228 | sym=True,
|
@@ -448,22 +449,30 @@ def test_ovmodel_4bit_auto_compression(self, model_cls, model_type, expected_ov_
|
448 | 449 | def test_ovmodel_4bit_auto_compression_with_config(
|
449 | 450 | self, model_cls, model_name, quantization_config, expected_ov_int4
|
450 | 451 | ):
|
| 452 | + # If this variable is defined locally, collect_descriptions() will, for some unclear reason, append values into the list |
| 453 | + # defined for the first test case |
| 454 | + if "track_descriptions" not in globals(): |
| 455 | + globals()["track_descriptions"] = [] |
| 456 | + track_descriptions = globals()["track_descriptions"] |
| 457 | + track_descriptions.clear() |
| 458 | + |
| 459 | + def collect_descriptions(*args, **kwargs): |
| 460 | + track_descriptions.append(kwargs["description"]) |
| 461 | + return unittest.mock.DEFAULT |
| 462 | + |
451 | 463 | model_id = MODEL_NAMES[model_name]
|
452 | 464 | with tempfile.TemporaryDirectory() as tmp_dir:
|
453 | 465 | quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
|
454 | 466 |
|
455 |
| - from nncf.common.logging.track_progress import track |
456 |
| - |
457 |
| - with unittest.mock.patch("nncf.common.logging.track_progress.track", wraps=track) as track_patch: |
| 467 | + with unittest.mock.patch( |
| 468 | + "nncf.common.logging.track_progress.track", |
| 469 | + wraps=track, |
| 470 | + side_effect=collect_descriptions |
| 471 | + ): |
458 | 472 | model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
|
459 | 473 | if quantization_config.quant_method == QuantizationMethod.AWQ:
|
460 | 474 | # Called at least once with description="Applying AWQ"
|
461 |
| - self.assertTrue( |
462 |
| - any( |
463 |
| - args.kwargs.get("description", None) == "Applying AWQ" |
464 |
| - for args in track_patch.call_args_list |
465 |
| - ) |
466 |
| - ) |
| 475 | + self.assertTrue(any(it == "Applying AWQ" for it in track_descriptions)) |
467 | 476 |
|
468 | 477 | tokenizer = AutoTokenizer.from_pretrained(model_id)
|
469 | 478 | if tokenizer.pad_token is None:
|
|
0 commit comments