9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
11
12
-
13
- import time
14
- import traceback
15
12
from collections import OrderedDict
16
13
from pathlib import Path
17
14
from typing import Dict , Optional
23
20
from tests .post_training .experimental .sparsify_activations .model_scope import SPARSIFY_ACTIVATIONS_TEST_CASES
24
21
from tests .post_training .experimental .sparsify_activations .pipelines import SARunInfo
25
22
from tests .post_training .pipelines .base import BackendType
26
- from tests .post_training .pipelines .base import BaseTestPipeline
27
- from tests .post_training .test_quantize_conformance import create_short_run_info
28
23
from tests .post_training .test_quantize_conformance import fixture_batch_size # noqa: F401
29
24
from tests .post_training .test_quantize_conformance import fixture_data # noqa: F401
30
25
from tests .post_training .test_quantize_conformance import fixture_extra_columns # noqa: F401
34
29
from tests .post_training .test_quantize_conformance import fixture_run_fp32_backend # noqa: F401
35
30
from tests .post_training .test_quantize_conformance import fixture_run_torch_cuda_backend # noqa: F401
36
31
from tests .post_training .test_quantize_conformance import fixture_subset_size # noqa: F401
37
- from tests .post_training .test_quantize_conformance import maybe_skip_test_case
38
- from tests .post_training .test_quantize_conformance import write_logs
32
+ from tests .post_training .test_quantize_conformance import run_pipeline
39
33
40
34
41
35
@pytest .fixture (scope = "session" , name = "sparsify_activations_reference_data" )
@@ -59,39 +53,6 @@ def fixture_sparsify_activations_report_data(output_dir):
59
53
df .to_csv (output_dir / "results.csv" , index = False )
60
54
61
55
62
- def create_pipeline_kwargs (
63
- test_model_param : Dict ,
64
- subset_size ,
65
- test_case_name : str ,
66
- reference_data : Dict [str , Dict ],
67
- fp32_model_params : Dict [str , Dict ],
68
- ):
69
- if subset_size :
70
- if "compression_params" not in test_model_param :
71
- test_model_param ["compression_params" ] = {}
72
- test_model_param ["compression_params" ]["subset_size" ] = subset_size
73
-
74
- print ("\n " )
75
- print (f"Model: { test_model_param ['reported_name' ]} " )
76
- print (f"Backend: { test_model_param ['backend' ]} " )
77
- print (f"Comprssion params: { test_model_param ['compression_params' ]} " )
78
-
79
- # Get target fp32 metric value
80
- model_id = test_model_param ["model_id" ]
81
- fp32_test_case_name = fp32_model_params [model_id ]["reported_name" ] + f"_backend_{ BackendType .FP32 .value } "
82
- test_reference = reference_data [test_case_name ]
83
- test_reference ["metric_value_fp32" ] = reference_data [fp32_test_case_name ]["metric_value" ]
84
-
85
- return {
86
- "reported_name" : test_model_param ["reported_name" ],
87
- "model_id" : test_model_param ["model_id" ],
88
- "backend" : test_model_param ["backend" ],
89
- "compression_params" : test_model_param ["compression_params" ],
90
- "params" : test_model_param .get ("params" ),
91
- "reference_data" : test_reference ,
92
- }
93
-
94
-
95
56
@pytest .mark .parametrize ("test_case_name" , SPARSIFY_ACTIVATIONS_TEST_CASES .keys ())
96
57
def test_sparsify_activations (
97
58
sparsify_activations_reference_data : dict ,
@@ -108,55 +69,26 @@ def test_sparsify_activations(
108
69
capsys : pytest .CaptureFixture ,
109
70
extra_columns : bool ,
110
71
):
111
- pipeline = None
112
- err_msg = None
113
- test_model_param = None
114
- start_time = time .perf_counter ()
115
- try :
116
- if test_case_name not in sparsify_activations_reference_data :
117
- msg = f"{ test_case_name } is not defined in `sparsify_activations_reference_data` fixture"
118
- raise RuntimeError (msg )
119
- test_model_param = SPARSIFY_ACTIVATIONS_TEST_CASES [test_case_name ]
120
- maybe_skip_test_case (test_model_param , run_fp32_backend , run_torch_cuda_backend , batch_size )
121
- fp32_model_params = {
122
- tc ["model_id" ]: tc for tc in SPARSIFY_ACTIVATIONS_TEST_CASES .values () if tc ["backend" ] == BackendType .FP32
123
- }
124
- pipeline_cls = test_model_param ["pipeline_cls" ]
125
- pipeline_kwargs = create_pipeline_kwargs (
126
- test_model_param , subset_size , test_case_name , sparsify_activations_reference_data , fp32_model_params
127
- )
128
- calibration_batch_size = batch_size or test_model_param .get ("batch_size" , 1 )
129
- pipeline_kwargs .update (
130
- {
131
- "output_dir" : output_dir ,
132
- "data_dir" : data_dir ,
133
- "no_eval" : no_eval ,
134
- "run_benchmark_app" : run_benchmark_app ,
135
- "batch_size" : calibration_batch_size ,
136
- }
137
- )
138
- pipeline : BaseTestPipeline = pipeline_cls (** pipeline_kwargs )
139
- pipeline .run ()
140
- except Exception as e :
141
- err_msg = str (e )
142
- traceback .print_exc ()
143
-
144
- if pipeline is not None :
145
- pipeline .cleanup_cache ()
146
- run_info = pipeline .run_info
147
- if err_msg :
148
- run_info .status = f"{ run_info .status } | { err_msg } " if run_info .status else err_msg
149
-
150
- captured = capsys .readouterr ()
151
- write_logs (captured , pipeline )
152
-
153
- if extra_columns :
154
- pipeline .collect_data_from_stdout (captured .out )
155
- else :
156
- run_info = create_short_run_info (test_model_param , err_msg , test_case_name )
157
-
158
- run_info .time_total = time .perf_counter () - start_time
159
- sparsify_activations_result_data [test_case_name ] = run_info
160
-
161
- if err_msg :
162
- pytest .fail (err_msg )
72
+ fp32_model_params = {
73
+ tc ["model_id" ]: tc for tc in SPARSIFY_ACTIVATIONS_TEST_CASES .values () if tc ["backend" ] == BackendType .FP32
74
+ }
75
+ run_pipeline (
76
+ test_case_name ,
77
+ sparsify_activations_reference_data ,
78
+ SPARSIFY_ACTIVATIONS_TEST_CASES ,
79
+ sparsify_activations_result_data ,
80
+ output_dir ,
81
+ data_dir ,
82
+ no_eval ,
83
+ batch_size ,
84
+ run_fp32_backend ,
85
+ run_torch_cuda_backend ,
86
+ subset_size ,
87
+ run_benchmark_app ,
88
+ False , # torch_compile_validation is not used in SA
89
+ capsys ,
90
+ extra_columns ,
91
+ False , # memory_monitor is not used in SA
92
+ None , # use_avx2 is not used in SA
93
+ fp32_model_params ,
94
+ )
0 commit comments