|
9 | 9 | # See the License for the specific language governing permissions and
|
10 | 10 | # limitations under the License.
|
11 | 11 |
|
12 |
| -import copy |
13 |
| -from typing import Dict, List |
14 |
| - |
15 |
| -import nncf |
16 | 12 | from nncf.experimental.torch.sparsify_activations import TargetScope
|
17 | 13 | from nncf.parameters import CompressWeightsMode
|
18 | 14 | from tests.post_training.experimental.sparsify_activations.pipelines import ImageClassificationTimmSparsifyActivations
|
19 | 15 | from tests.post_training.experimental.sparsify_activations.pipelines import LMSparsifyActivations
|
| 16 | +from tests.post_training.model_scope import generate_tests_scope |
20 | 17 | from tests.post_training.pipelines.base import BackendType
|
21 | 18 |
|
22 | 19 | SPARSIFY_ACTIVATIONS_MODELS = [
|
|
30 | 27 | {
|
31 | 28 | "reported_name": "tinyllama_ffn_sparse20",
|
32 | 29 | "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
|
| 30 | + "model_name": "tinyllama", |
33 | 31 | "pipeline_cls": LMSparsifyActivations,
|
34 | 32 | "compression_params": {
|
35 | 33 | "compress_weights": None,
|
|
45 | 43 | {
|
46 | 44 | "reported_name": "tinyllama_int8_asym_data_free_ffn_sparse20",
|
47 | 45 | "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
|
| 46 | + "model_name": "tinyllama", |
48 | 47 | "pipeline_cls": LMSparsifyActivations,
|
49 | 48 | "compression_params": {
|
50 | 49 | "compress_weights": {
|
|
62 | 61 | {
|
63 | 62 | "reported_name": "timm/deit3_small_patch16_224",
|
64 | 63 | "model_id": "deit3_small_patch16_224",
|
| 64 | + "model_name": "timm/deit3_small_patch16_224", |
65 | 65 | "pipeline_cls": ImageClassificationTimmSparsifyActivations,
|
66 | 66 | "compression_params": {},
|
67 | 67 | "backends": [BackendType.FP32],
|
|
70 | 70 | {
|
71 | 71 | "reported_name": "timm/deit3_small_patch16_224_qkv_sparse20_fc1_sparse20_fc2_sparse30",
|
72 | 72 | "model_id": "deit3_small_patch16_224",
|
| 73 | + "model_name": "timm/deit3_small_patch16_224", |
73 | 74 | "pipeline_cls": ImageClassificationTimmSparsifyActivations,
|
74 | 75 | "compression_params": {
|
75 | 76 | "sparsify_activations": {
|
|
85 | 86 | ]
|
86 | 87 |
|
87 | 88 |
|
88 |
| -def generate_tests_scope(models_list: List[Dict]) -> Dict[str, Dict]: |
89 |
| - """ |
90 |
| - Generate tests by names "{reported_name}_backend_{backend}" |
91 |
| - """ |
92 |
| - tests_scope = {} |
93 |
| - fp32_models = set() |
94 |
| - for test_model_param in models_list: |
95 |
| - model_id = test_model_param["model_id"] |
96 |
| - reported_name = test_model_param["reported_name"] |
97 |
| - |
98 |
| - for backend in test_model_param["backends"]: |
99 |
| - model_param = copy.deepcopy(test_model_param) |
100 |
| - if "is_batch_size_supported" not in model_param: # Set default value of is_batch_size_supported. |
101 |
| - model_param["is_batch_size_supported"] = True |
102 |
| - test_case_name = f"{reported_name}_backend_{backend.value}" |
103 |
| - model_param["backend"] = backend |
104 |
| - model_param.pop("backends") |
105 |
| - if backend == BackendType.FP32: |
106 |
| - if model_id in fp32_models: |
107 |
| - msg = f"Duplicate test case for {model_id} with FP32 backend" |
108 |
| - raise nncf.ValidationError(msg) |
109 |
| - fp32_models.add(model_id) |
110 |
| - if test_case_name in tests_scope: |
111 |
| - msg = f"{test_case_name} already in tests_scope" |
112 |
| - raise nncf.ValidationError(msg) |
113 |
| - tests_scope[test_case_name] = model_param |
114 |
| - |
115 |
| - return tests_scope |
116 |
| - |
117 |
| - |
118 | 89 | SPARSIFY_ACTIVATIONS_TEST_CASES = generate_tests_scope(SPARSIFY_ACTIVATIONS_MODELS)
|
0 commit comments