19
19
from datetime import timedelta
20
20
from enum import Enum
21
21
from pathlib import Path
22
- from typing import Dict , Optional
22
+ from typing import Dict , List , Optional
23
23
24
+ import numpy as np
24
25
import onnx
25
26
import openvino as ov
26
27
import torch
@@ -179,7 +180,8 @@ def format_memory_usage(memory):
179
180
return None
180
181
return int (memory )
181
182
182
- def get_result_dict (self ):
183
+ def get_result_dict (self ) -> Dict [str , str ]:
184
+ """Returns a dictionary with the results of the run."""
183
185
ram_data = {}
184
186
if self .compression_memory_usage_rss is None and self .compression_memory_usage_system is None :
185
187
ram_data ["RAM MiB" ] = self .format_memory_usage (self .compression_memory_usage )
@@ -194,10 +196,6 @@ def get_result_dict(self):
194
196
"Metric name" : self .metric_name ,
195
197
"Metric value" : self .metric_value ,
196
198
"Metric diff" : self .metric_diff ,
197
- "Num FQ" : self .num_compress_nodes .num_fq_nodes ,
198
- "Num int4" : self .num_compress_nodes .num_int4 ,
199
- "Num int8" : self .num_compress_nodes .num_int8 ,
200
- "Num sparse activations" : self .num_compress_nodes .num_sparse_activations ,
201
199
"Compr. time" : self .format_time (self .time_compression ),
202
200
** self .stats_from_output .get_stats (),
203
201
"Total time" : self .format_time (self .time_total ),
@@ -209,6 +207,15 @@ def get_result_dict(self):
209
207
return result
210
208
211
209
210
@dataclass
class PTQRunInfo(RunInfo):
    """Run information for PTQ runs: the base report plus quantizer counters."""

    def get_result_dict(self):
        """
        Returns the base result dictionary extended with the "Num FQ" and "Num int8" entries.

        Both values are read from `self.num_compress_nodes`.
        """
        report = super().get_result_dict()
        counters = self.num_compress_nodes
        report["Num FQ"] = counters.num_fq_nodes
        report["Num int8"] = counters.num_int8
        return report
217
+
218
+
212
219
class BaseTestPipeline (ABC ):
213
220
"""
214
221
Base class to test compression algorithms.
@@ -286,9 +293,28 @@ def compress(self) -> None:
286
293
def save_compressed_model (self ) -> None :
287
294
"""Save compressed model to IR."""
288
295
289
- @abstractmethod
290
296
def get_num_compressed(self) -> None:
    """
    Get number of the compressed nodes in the compressed IR.

    Reads the model stored at `self.path_compressed_ir`, counts the FakeQuantize
    operations and the node outputs whose element type is 8-bit (i8/u8) or
    4-bit (i4/u4/nf4), and stores the counters in `self.run_info.num_compress_nodes`.
    """
    int8_type_names = ["i8", "u8"]
    int4_type_names = ["i4", "u4", "nf4"]

    core = ov.Core()
    compressed_ir = core.read_model(model=self.path_compressed_ir)

    fq_total, int4_total, int8_total = 0, 0, 0
    for op in compressed_ir.get_ops():
        if op.type_info.name == "FakeQuantize":
            fq_total += 1

        # Every output port of every operation is inspected, not only the first one.
        for port in range(op.get_output_size()):
            type_name = op.get_output_element_type(port).get_type_name()
            if type_name in int8_type_names:
                int8_total += 1
            elif type_name in int4_type_names:
                int4_total += 1

    counters = self.run_info.num_compress_nodes
    counters.num_int8 = int8_total
    counters.num_int4 = int4_total
    counters.num_fq_nodes = fq_total
292
318
293
319
@abstractmethod
294
320
def run_bench (self ) -> None :
@@ -334,6 +360,61 @@ def run(self) -> None:
334
360
self .validate ()
335
361
self .run_bench ()
336
362
363
def collect_errors(self) -> List[ErrorReport]:
    """
    Collects errors by comparing the measured metric with the reference data.

    As a side effect, `self.run_info.metric_diff` is set to the difference between the
    measured metric and the reference "metric_value_fp32" (rounded to 5 digits) when both
    values are available.

    A metric that differs from the reference "metric_value" by more than the tolerance is
    reported in both directions: as a regression when it is lower and as an improvement
    when it is higher.

    :return: List of error reports; empty when there is nothing to compare or the metric
        is close to the reference.
    """
    info = self.run_info
    reference = self.reference_data

    measured = info.metric_value
    expected = reference.get("metric_value")
    fp32_value = reference.get("metric_value_fp32")

    if measured is not None and fp32_value is not None:
        info.metric_diff = round(measured - fp32_value, 5)

    # Without both a measured and a reference value there is nothing to validate.
    if measured is None or expected is None:
        return []

    tolerance = reference.get("atol", 0.001)
    if np.isclose(measured, expected, atol=tolerance):
        return []

    if measured < expected:
        message = f"Regression: Metric value is less than reference {measured} < {expected} "
    else:
        message = f"Improvement: Metric value is better than reference {measured} > {expected} "
    return [ErrorReport(ErrorReason.METRICS, message)]
393
+
394
def update_status(self, error_reports: List[ErrorReport]) -> List[str]:
    """
    Updates `self.run_info.status` based on the errors encountered during the run.

    The status is first reset to an empty string (success). Each report accepted by
    `_is_error_xfailed` contributes an XFAIL message; every other report counts as an
    unexpected error. If any unexpected errors exist, their messages replace the XFAIL
    text in the status.

    :param error_reports: List of errors encountered during the run.
    :return: List of unexpected error messages.
    """
    self.run_info.status = ""  # Successful status
    expected_failures = []
    unexpected_errors = []

    for error in error_reports:
        reason_key = error.reason.value + XFAIL_SUFFIX
        if not _is_error_xfailed(error, reason_key, self.reference_data):
            unexpected_errors.append(error.msg)
            continue
        expected_failures.append(_get_xfail_message(error, reason_key, self.reference_data))

    # Unexpected errors are applied last so that they take precedence over XFAIL messages.
    for messages in (expected_failures, unexpected_errors):
        if messages:
            self.run_info.status = "\n ".join(messages)
    return unexpected_errors
417
+
337
418
338
419
class PTQTestPipeline (BaseTestPipeline ):
339
420
"""
@@ -421,28 +502,6 @@ def save_compressed_model(self) -> None:
421
502
apply_moc_transformations (self .compressed_model , cf = True )
422
503
ov .serialize (self .compressed_model , str (self .path_compressed_ir ))
423
504
424
- def get_num_compressed (self ) -> None :
425
- ie = ov .Core ()
426
- model = ie .read_model (model = self .path_compressed_ir )
427
-
428
- num_fq = 0
429
- num_int4 = 0
430
- num_int8 = 0
431
- for node in model .get_ops ():
432
- node_type = node .type_info .name
433
- if node_type == "FakeQuantize" :
434
- num_fq += 1
435
-
436
- for i in range (node .get_output_size ()):
437
- if node .get_output_element_type (i ).get_type_name () in ["i8" , "u8" ]:
438
- num_int8 += 1
439
- if node .get_output_element_type (i ).get_type_name () in ["i4" , "u4" , "nf4" ]:
440
- num_int4 += 1
441
-
442
- self .run_info .num_compress_nodes .num_int8 = num_int8
443
- self .run_info .num_compress_nodes .num_int4 = num_int4
444
- self .run_info .num_compress_nodes .num_fq_nodes = num_fq
445
-
446
505
def run_bench (self ) -> None :
447
506
"""
448
507
Run benchmark_app to collect performance statistics.
@@ -476,3 +535,32 @@ def collect_data_from_stdout(self, stdout: str):
476
535
stats = PTQTimeStats ()
477
536
stats .fill (stdout )
478
537
self .run_info .stats_from_output = stats
538
+
539
+
540
def _get_exception_type_name(report: ErrorReport) -> str:
    """
    Extracts the exception type name from the message of an error report.

    The type name is the part of `report.msg` before the first "|" with the
    "Exception Type: " prefix removed.
    """
    type_part, _, _ = report.msg.partition("|")
    return type_part.replace("Exception Type: ", "")
542
+
543
+
544
def _get_exception_error_message(report: ErrorReport) -> str:
    """
    Extracts the error message from the message of an error report.

    The error message is everything in `report.msg` after the first "|". Only the first
    separator is significant, so an error message that itself contains "|" is returned
    in full instead of being cut at the next "|".

    :param report: Error report whose `msg` holds the exception type and message joined by "|".
    :return: The error message part, or an empty string if `report.msg` has no "|".
    """
    # `partition` splits on the first "|" only and never raises when it is missing.
    _, _, error_message = report.msg.partition("|")
    return error_message
546
+
547
+
548
def _are_exceptions_matched(report: ErrorReport, reference_exception: Dict[str, str]) -> bool:
    """
    Checks whether the exception described by the report equals the reference exception.

    :param report: Error report of the raised exception.
    :param reference_exception: Reference entry with the "error_message" and "type" keys.
    :return: True if both the error message and the exception type match.
    """
    # The message is compared first; the type is only looked at when the messages agree.
    if reference_exception["error_message"] != _get_exception_error_message(report):
        return False
    return reference_exception["type"] == _get_exception_type_name(report)
552
+
553
+
554
def _is_error_xfailed(report: ErrorReport, xfail_reason: str, reference_data: Dict[str, Dict[str, str]]) -> bool:
    """
    Checks whether the reported error is an expected failure according to the reference data.

    :param report: Error report to check.
    :param xfail_reason: Key of the xfail entry to look up in `reference_data`.
    :param reference_data: Reference data that may contain the xfail entry.
    :return: False if `reference_data` has no `xfail_reason` entry. For exception reports,
        whether the exception matches that entry; True for any other reported reason.
    """
    if xfail_reason not in reference_data:
        return False

    is_exception = report.reason == ErrorReason.EXCEPTION
    return _are_exceptions_matched(report, reference_data[xfail_reason]) if is_exception else True
561
+
562
+
563
def _get_xfail_message(report: ErrorReport, xfail_reason: str, reference_data: Dict[str, Dict[str, str]]) -> str:
    """
    Builds the XFAIL status message for the reported error.

    :param report: Error report that is treated as an expected failure.
    :param xfail_reason: Key of the xfail entry in `reference_data`.
    :param reference_data: Reference data; for exception reports its `xfail_reason`
        entry must provide the "message" key.
    :return: Status message that starts with "XFAIL: " and includes `report.msg`.
    """
    if report.reason == ErrorReason.EXCEPTION:
        # Exceptions are labelled with the message stored in the reference entry.
        label = reference_data[xfail_reason]["message"]
    else:
        label = xfail_reason
    return f"XFAIL: {label} - {report.msg} "
0 commit comments