Skip to content

Commit 6a56169

Browse files
committed
use some OOP
1 parent 228dd37 commit 6a56169

File tree

4 files changed

+251
-198
lines changed

4 files changed

+251
-198
lines changed

tests/post_training/experimental/sparsify_activations/pipelines.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
3333
from tests.post_training.pipelines.base import PT_BACKENDS
3434
from tests.post_training.pipelines.base import BackendType
35-
from tests.post_training.pipelines.base import PTQTestPipeline
35+
from tests.post_training.pipelines.base import ErrorReason
36+
from tests.post_training.pipelines.base import ErrorReport
37+
from tests.post_training.pipelines.base import RunInfo
3638
from tests.post_training.pipelines.image_classification_timm import ImageClassificationTimm
3739
from tests.post_training.pipelines.lm_weight_compression import LMWeightCompression
3840
from tests.post_training.pipelines.lm_weight_compression import WCTimeStats
@@ -52,7 +54,18 @@ class SATimeStats(WCTimeStats):
5254
REGEX_PREFIX = [*WCTimeStats.REGEX_PREFIX, SparsifyActivationsAlgoBackend.CALIBRATION_TRACKING_DESC]
5355

5456

55-
class SAPipelineMixin(PTQTestPipeline):
57+
@dataclass
class SARunInfo(RunInfo):
    """
    Run information used by the Sparsify Activations pipelines.

    Adds the compressed-node counters to the result dictionary of the parent class.
    """

    def get_result_dict(self):
        """
        Builds the result dictionary of the run.

        :return: The parent's result dictionary extended with the "Num FQ", "Num int4",
            "Num int8" and "Num sparse activations" entries.
        """
        report = super().get_result_dict()
        counters = self.num_compress_nodes
        report.update(
            {
                "Num FQ": counters.num_fq_nodes,
                "Num int4": counters.num_int4,
                "Num int8": counters.num_int8,
                "Num sparse activations": counters.num_sparse_activations,
            }
        )
        return report
66+
67+
68+
class SAPipelineMixin(LMWeightCompression):
5669
"""
5770
Common methods in the test pipeline for Sparsify Activations.
5871
"""
@@ -88,8 +101,24 @@ def get_num_compressed(self) -> None:
88101
model = ie.read_model(model=self.path_compressed_ir)
89102
self.run_info.num_compress_nodes.num_sparse_activations = count_sparsifier_patterns_in_ov(model)
90103

104+
def collect_errors(self) -> List[ErrorReport]:
    """
    Collects the parent's errors and adds a check of the sparse activation count.

    The check is skipped when the reference data has no "num_sparse_activations" entry.

    :return: List of error reports.
    """
    reports = super().collect_errors()

    expected = self.reference_data.get("num_sparse_activations")
    if expected is None:
        # Nothing to compare against.
        return reports

    actual = self.run_info.num_compress_nodes.num_sparse_activations
    if actual != expected:
        message = (
            f"Regression: The number of sparse activations is {actual}, "
            f"which differs from reference {expected}."
        )
        reports.append(ErrorReport(ErrorReason.NUM_COMPRESSED, message))
    return reports
119+
91120

92-
class LMSparsifyActivations(SAPipelineMixin, LMWeightCompression):
121+
class LMSparsifyActivations(SAPipelineMixin):
93122
DEFAULT_SUBSET_SIZE = 32
94123

95124
def prepare_model(self):

tests/post_training/pipelines/base.py

+117-29
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
from datetime import timedelta
2020
from enum import Enum
2121
from pathlib import Path
22-
from typing import Dict, Optional
22+
from typing import Dict, List, Optional
2323

24+
import numpy as np
2425
import onnx
2526
import openvino as ov
2627
import torch
@@ -179,7 +180,8 @@ def format_memory_usage(memory):
179180
return None
180181
return int(memory)
181182

182-
def get_result_dict(self):
183+
def get_result_dict(self) -> Dict[str, str]:
184+
"""Returns a dictionary with the results of the run."""
183185
ram_data = {}
184186
if self.compression_memory_usage_rss is None and self.compression_memory_usage_system is None:
185187
ram_data["RAM MiB"] = self.format_memory_usage(self.compression_memory_usage)
@@ -194,10 +196,6 @@ def get_result_dict(self):
194196
"Metric name": self.metric_name,
195197
"Metric value": self.metric_value,
196198
"Metric diff": self.metric_diff,
197-
"Num FQ": self.num_compress_nodes.num_fq_nodes,
198-
"Num int4": self.num_compress_nodes.num_int4,
199-
"Num int8": self.num_compress_nodes.num_int8,
200-
"Num sparse activations": self.num_compress_nodes.num_sparse_activations,
201199
"Compr. time": self.format_time(self.time_compression),
202200
**self.stats_from_output.get_stats(),
203201
"Total time": self.format_time(self.time_total),
@@ -209,6 +207,15 @@ def get_result_dict(self):
209207
return result
210208

211209

210+
@dataclass
class PTQRunInfo(RunInfo):
    """
    Run information that additionally reports the FakeQuantize and int8 counters.
    """

    def get_result_dict(self):
        """
        Builds the result dictionary of the run.

        :return: The parent's result dictionary extended with the "Num FQ" and "Num int8" entries.
        """
        report = super().get_result_dict()
        counters = self.num_compress_nodes
        report.update(
            {
                "Num FQ": counters.num_fq_nodes,
                "Num int8": counters.num_int8,
            }
        )
        return report
217+
218+
212219
class BaseTestPipeline(ABC):
213220
"""
214221
Base class to test compression algorithms.
@@ -286,9 +293,28 @@ def compress(self) -> None:
286293
def save_compressed_model(self) -> None:
287294
"""Save compressed model to IR."""
288295

289-
@abstractmethod
290296
def get_num_compressed(self) -> None:
    """
    Get number of the compressed nodes in the compressed IR.

    Reads the model from `self.path_compressed_ir`, counts the "FakeQuantize" operations and the
    operation outputs whose element type is int8-like ("i8", "u8") or int4-like ("i4", "u4", "nf4"),
    and stores the three counters in `self.run_info.num_compress_nodes`.
    """
    int8_type_names = ("i8", "u8")
    int4_type_names = ("i4", "u4", "nf4")

    compressed_model = ov.Core().read_model(model=self.path_compressed_ir)

    fq_count = 0
    int4_count = 0
    int8_count = 0
    for op in compressed_model.get_ops():
        if op.type_info.name == "FakeQuantize":
            fq_count += 1

        # Every output port of every operation is inspected, not only the first one.
        for port in range(op.get_output_size()):
            type_name = op.get_output_element_type(port).get_type_name()
            if type_name in int8_type_names:
                int8_count += 1
            if type_name in int4_type_names:
                int4_count += 1

    counters = self.run_info.num_compress_nodes
    counters.num_int8 = int8_count
    counters.num_int4 = int4_count
    counters.num_fq_nodes = fq_count
292318

293319
@abstractmethod
294320
def run_bench(self) -> None:
@@ -334,6 +360,61 @@ def run(self) -> None:
334360
self.validate()
335361
self.run_bench()
336362

363+
def collect_errors(self) -> List[ErrorReport]:
    """
    Collects errors based on the pipeline's run information.

    As a side effect, `run_info.metric_diff` is set to the rounded difference between the measured
    metric and the "metric_value_fp32" reference when both are available.

    A metric report is produced when the measured metric is not close to the "metric_value"
    reference within the "atol" tolerance (0.001 if the reference data does not define it);
    values both below and above the reference are reported.

    :return: List of error reports.
    """
    reports: List[ErrorReport] = []

    info = self.run_info
    references = self.reference_data

    measured = info.metric_value
    expected = references.get("metric_value")
    fp32_value = references.get("metric_value_fp32")

    if measured is None:
        # Without a measured metric there is nothing to compare.
        return reports

    if fp32_value is not None:
        info.metric_diff = round(measured - fp32_value, 5)

    if expected is None:
        return reports

    tolerance = references.get("atol", 0.001)
    if np.isclose(measured, expected, atol=tolerance):
        return reports

    if measured < expected:
        message = f"Regression: Metric value is less than reference {measured} < {expected}"
    else:
        message = f"Improvement: Metric value is better than reference {measured} > {expected}"
    reports.append(ErrorReport(ErrorReason.METRICS, message))
    return reports
393+
394+
def update_status(self, error_reports: List[ErrorReport]) -> List[str]:
    """
    Updates status of the pipeline based on the errors encountered during the run.

    Every report is classified either as an expected failure (xfailed by the reference data) or
    as an unexpected error. The resulting status is the newline-joined unexpected messages when
    there are any, otherwise the newline-joined xfail messages, otherwise an empty string.

    :param error_reports: List of errors encountered during the run.
    :return: List of unexpected errors.
    """
    self.run_info.status = ""  # Successful status
    expected_failures: List[str] = []
    unexpected_errors: List[str] = []

    for report in error_reports:
        xfail_reason = report.reason.value + XFAIL_SUFFIX
        if not _is_error_xfailed(report, xfail_reason, self.reference_data):
            unexpected_errors.append(report.msg)
            continue
        expected_failures.append(_get_xfail_message(report, xfail_reason, self.reference_data))

    # Unexpected errors take precedence over xfail messages in the reported status.
    status_lines = unexpected_errors or expected_failures
    if status_lines:
        self.run_info.status = "\n".join(status_lines)
    return unexpected_errors
417+
337418

338419
class PTQTestPipeline(BaseTestPipeline):
339420
"""
@@ -421,28 +502,6 @@ def save_compressed_model(self) -> None:
421502
apply_moc_transformations(self.compressed_model, cf=True)
422503
ov.serialize(self.compressed_model, str(self.path_compressed_ir))
423504

424-
def get_num_compressed(self) -> None:
425-
ie = ov.Core()
426-
model = ie.read_model(model=self.path_compressed_ir)
427-
428-
num_fq = 0
429-
num_int4 = 0
430-
num_int8 = 0
431-
for node in model.get_ops():
432-
node_type = node.type_info.name
433-
if node_type == "FakeQuantize":
434-
num_fq += 1
435-
436-
for i in range(node.get_output_size()):
437-
if node.get_output_element_type(i).get_type_name() in ["i8", "u8"]:
438-
num_int8 += 1
439-
if node.get_output_element_type(i).get_type_name() in ["i4", "u4", "nf4"]:
440-
num_int4 += 1
441-
442-
self.run_info.num_compress_nodes.num_int8 = num_int8
443-
self.run_info.num_compress_nodes.num_int4 = num_int4
444-
self.run_info.num_compress_nodes.num_fq_nodes = num_fq
445-
446505
def run_bench(self) -> None:
447506
"""
448507
Run benchmark_app to collect performance statistics.
@@ -476,3 +535,32 @@ def collect_data_from_stdout(self, stdout: str):
476535
stats = PTQTimeStats()
477536
stats.fill(stdout)
478537
self.run_info.stats_from_output = stats
538+
539+
540+
def _get_exception_type_name(report: ErrorReport) -> str:
    """
    Extracts the exception type name from the report message.

    :param report: Error report whose message holds the type part before the first "|".
    :return: The text before the first "|" with every "Exception Type: " occurrence removed.
    """
    type_part, _, _ = report.msg.partition("|")
    return type_part.replace("Exception Type: ", "")
542+
543+
544+
def _get_exception_error_message(report: ErrorReport) -> str:
    """
    Extracts the exception error message from the report message.

    :param report: Error report whose message is split on "|".
    :return: The second "|"-separated field of the report message.
    :raises IndexError: If the report message contains no "|".
    """
    fields = report.msg.split("|")
    return fields[1]
546+
547+
548+
def _are_exceptions_matched(report: ErrorReport, reference_exception: Dict[str, str]) -> bool:
    """
    Checks whether the reported exception equals the reference exception.

    :param report: Error report describing the raised exception.
    :param reference_exception: Reference entry with the "error_message" and "type" keys.
    :return: True if both the error message and the exception type match, False otherwise.
    """
    if reference_exception["error_message"] != _get_exception_error_message(report):
        return False
    return reference_exception["type"] == _get_exception_type_name(report)
552+
553+
554+
def _is_error_xfailed(report: ErrorReport, xfail_reason: str, reference_data: Dict[str, Dict[str, str]]) -> bool:
    """
    Checks whether the reported error is an expected failure according to the reference data.

    :param report: Error report to classify.
    :param xfail_reason: Key under which an expected failure is looked up in the reference data.
    :param reference_data: Reference data of the test case.
    :return: False if the reference data has no `xfail_reason` entry. For exception reports the
        result of matching the exception against that entry; True for any other report.
    """
    if xfail_reason in reference_data:
        if report.reason != ErrorReason.EXCEPTION:
            return True
        # Exceptions are xfailed only when they match the reference type and message.
        return _are_exceptions_matched(report, reference_data[xfail_reason])
    return False
561+
562+
563+
def _get_xfail_message(report: ErrorReport, xfail_reason: str, reference_data: Dict[str, Dict[str, str]]) -> str:
    """
    Builds the status message for an expected failure.

    :param report: Error report that was classified as an expected failure.
    :param xfail_reason: Key of the expected failure in the reference data.
    :param reference_data: Reference data of the test case.
    :return: "XFAIL: <detail> - <report message>", where the detail is the reference "message"
        entry for exception reports and `xfail_reason` for any other report.
    """
    if report.reason == ErrorReason.EXCEPTION:
        detail = reference_data[xfail_reason]["message"]
    else:
        detail = xfail_reason
    return f"XFAIL: {detail} - {report.msg}"

tests/post_training/pipelines/lm_weight_compression.py

+76-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
import shutil
1515
import time
1616
from dataclasses import dataclass
17-
from typing import Dict, Optional
17+
from pathlib import Path
18+
from typing import Dict, List, Optional
1819

1920
import numpy as np
2021
import openvino as ov
@@ -30,7 +31,11 @@
3031
import nncf
3132
from tests.cross_fw.shared.paths import TEST_ROOT
3233
from tests.post_training.pipelines.base import BackendType
33-
from tests.post_training.pipelines.base import PTQTestPipeline
34+
from tests.post_training.pipelines.base import BaseTestPipeline
35+
from tests.post_training.pipelines.base import ErrorReason
36+
from tests.post_training.pipelines.base import ErrorReport
37+
from tests.post_training.pipelines.base import NumCompressNodes
38+
from tests.post_training.pipelines.base import RunInfo
3439
from tests.post_training.pipelines.base import StatsFromOutput
3540
from tools.memory_monitor import MemoryType
3641
from tools.memory_monitor import MemoryUnit
@@ -71,11 +76,53 @@ def get_stats(self) -> Dict[str, str]:
7176
return dict(zip(self.STAT_NAMES, VARS))
7277

7378

74-
class LMWeightCompression(PTQTestPipeline):
79+
@dataclass
class WCRunInfo(RunInfo):
    """
    Run information that additionally reports the int4 and int8 counters.
    """

    def get_result_dict(self):
        """
        Builds the result dictionary of the run.

        :return: The parent's result dictionary extended with the "Num int4" and "Num int8" entries.
        """
        report = super().get_result_dict()
        counters = self.num_compress_nodes
        report.update(
            {
                "Num int4": counters.num_int4,
                "Num int8": counters.num_int8,
            }
        )
        return report
86+
87+
88+
class LMWeightCompression(BaseTestPipeline):
7589
"""Pipeline for causal language models from Hugging Face repository"""
7690

7791
OV_MODEL_NAME = "openvino_model.xml"
7892

93+
def __init__(
    self,
    reported_name: str,
    model_id: str,
    backend: BackendType,
    compression_params: dict,
    output_dir: Path,
    data_dir: Path,
    reference_data: dict,
    no_eval: bool,
    run_benchmark_app: bool,
    torch_compile_validation: bool = False,
    params: Optional[dict] = None,
    batch_size: int = 1,
    memory_monitor: bool = False,
) -> None:
    """
    Initializes the pipeline and installs the weight compression run info.

    All arguments are forwarded unchanged and in the same order to the parent constructor.
    Afterwards `self.run_info` is set to a `WCRunInfo` built from `reported_name`,
    `self.backend` and a fresh `NumCompressNodes`.

    NOTE(review): the meaning of the individual arguments is defined by the parent constructor,
    which is not visible here - confirm against `BaseTestPipeline.__init__`.

    :param reported_name: Forwarded to the parent; also stored as the `model` field of the run info.
    :param params: Forwarded to the parent; may be None.
    """
    super().__init__(
        reported_name,
        model_id,
        backend,
        compression_params,
        output_dir,
        data_dir,
        reference_data,
        no_eval,
        run_benchmark_app,
        torch_compile_validation,
        params,
        batch_size,
        memory_monitor,
    )
    # Assigned after the parent constructor because it reads `self.backend`.
    self.run_info = WCRunInfo(model=reported_name, backend=self.backend, num_compress_nodes=NumCompressNodes())
125+
79126
def prepare_model(self) -> None:
80127
is_stateful = self.params.get("is_stateful", False)
81128

@@ -291,3 +338,29 @@ def _validate(self) -> None:
291338
similarity = all_metrics["similarity"][0]
292339
self.run_info.metric_name = "Similarity"
293340
self.run_info.metric_value = round(similarity, 5)
341+
342+
def collect_errors(self) -> List[ErrorReport]:
    """
    Collects the parent's errors and adds checks of the int4 and int8 counters.

    A counter is compared only when the reference data defines the matching "num_int4" or
    "num_int8" entry; each mismatch adds one report.

    :return: List of error reports.
    """
    reports = super().collect_errors()
    references = self.reference_data
    counters = self.run_info.num_compress_nodes

    # (label used in the message, reference value, measured value)
    checks = (
        ("int4", references.get("num_int4"), counters.num_int4),
        ("int8", references.get("num_int8"), counters.num_int8),
    )
    for label, expected, actual in checks:
        if expected is None or expected == actual:
            continue
        message = f"Regression: The number of {label} ops is different than reference {expected} != {actual}"
        reports.append(ErrorReport(ErrorReason.NUM_COMPRESSED, message))

    return reports

0 commit comments

Comments
 (0)