19
19
from datetime import timedelta
20
20
from enum import Enum
21
21
from pathlib import Path
22
- from typing import Dict , Optional
22
+ from typing import Dict , List , Optional
23
23
24
24
import numpy as np
25
25
import onnx
36
36
from tools .memory_monitor import memory_monitor_context
37
37
38
38
DEFAULT_VAL_THREADS = 4
39
- METRICS_XFAIL_REASON = "metrics_xfail_reason"
39
+ XFAIL_SUFFIX = "_xfail_reason"
40
+
41
+
42
class ErrorReason(Enum):
    """
    Category of a validation failure.

    The member value is the prefix used to look up an expected-failure
    entry ("<value>_xfail_reason") in the test's reference data.
    """

    METRICS = "metrics"
    NUM_COMPRESSED = "num_compressed"
45
+
46
+
47
@dataclass
class ErrorReport:
    """A single validation failure: its category and a human-readable message."""

    # Category used to match a "<reason>_xfail_reason" key in the reference data.
    reason: ErrorReason
    # Human-readable description of the failure.
    msg: str
40
51
41
52
42
53
class BackendType (Enum ):
@@ -278,9 +289,31 @@ def get_num_compressed(self) -> None:
278
289
    def run_bench(self) -> None:
        """Run a benchmark to collect performance statistics."""
        # NOTE(review): no-op in this base implementation — presumably
        # overridden by backend-specific subclasses; confirm against callers.
280
291
281
- @abstractmethod
282
- def _validate (self ) -> None :
283
- """Validate IR."""
292
    def _validate(self) -> List[ErrorReport]:
        """
        Validate test criteria; this base implementation checks nothing.

        :return: A list of error reports generated during validation
            (empty when all criteria pass).
        """
        return []
299
+
300
+ def _process_errors (self , errors ) -> str :
301
+ """
302
+ Processes a list of error reports and updates the run status.
303
+
304
+ :param errors: A list of error reports.
305
+ :return: A string representing the concatenated statuses of the processed errors.
306
+ """
307
+ xfails , msg_list = [], []
308
+ for report in errors :
309
+ xfail_reason = report .reason .value + XFAIL_SUFFIX
310
+ if xfail_reason in self .reference_data :
311
+ xfails .append (f"XFAIL: { self .reference_data [xfail_reason ]} - { report .msg } " )
312
+ else :
313
+ msg_list .append (report .msg )
314
+ if msg_list :
315
+ raise ValueError ("\n " .join (msg_list ))
316
+ self .run_info .status = "\n " .join (xfails )
284
317
285
318
def prepare (self ):
286
319
"""
@@ -302,7 +335,7 @@ def validate(self) -> None:
302
335
return
303
336
print ("Validation..." )
304
337
305
- self ._validate ()
338
+ errors = self ._validate ()
306
339
307
340
metric_value = self .run_info .metric_value
308
341
metric_reference = self .reference_data .get ("metric_value" )
@@ -311,22 +344,19 @@ def validate(self) -> None:
311
344
if metric_value is not None and metric_value_fp32 is not None :
312
345
self .run_info .metric_diff = round (self .run_info .metric_value - self .reference_data ["metric_value_fp32" ], 5 )
313
346
314
- status_msg = None
315
347
if (
316
348
metric_value is not None
317
349
and metric_reference is not None
318
350
and not np .isclose (metric_value , metric_reference , atol = self .reference_data .get ("atol" , 0.001 ))
319
351
):
352
+ status_msg = None
320
353
if metric_value < metric_reference :
321
354
status_msg = f"Regression: Metric value is less than reference { metric_value } < { metric_reference } "
322
355
if metric_value > metric_reference :
323
356
status_msg = f"Improvement: Metric value is better than reference { metric_value } > { metric_reference } "
324
-
325
- if status_msg is not None :
326
- if METRICS_XFAIL_REASON in self .reference_data :
327
- self .run_info .status = f"XFAIL: { self .reference_data [METRICS_XFAIL_REASON ]} - { status_msg } "
328
- else :
329
- raise ValueError (status_msg )
357
+ if status_msg :
358
+ errors .append (ErrorReport (ErrorReason .METRICS , status_msg ))
359
+ self ._process_errors (errors )
330
360
331
361
def run (self ) -> None :
332
362
"""
0 commit comments