[TESTS] Make a common entry point for conformance tests #3265

Merged
11 changes: 11 additions & 0 deletions tests/post_training/README.md
@@ -160,3 +160,14 @@ To mark a test as expected to fail (xfail) when a number of compression operations
  ...
  num_compressed_xfail_reason: "Issue-<jira ticket number>"
```

To mark a test as expected to fail (xfail) during the compression process with an exception:

```yml
<Name from model scopes>_backend_<BACKEND>:
  ...
  exception_xfail_reason:
    type: "<ExceptionType>"  # e.g. "TypeError"
    error_message: "<Error message from Exception>"
    message: "Issue-<jira ticket number>"
```
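
For illustration, a minimal sketch of how a test harness could act on such an entry; the helper below and the way it reads the reference data are assumptions for this example, not the actual conformance-test code:

```python
import pytest


def maybe_xfail_on_exception(exc: Exception, reference_data: dict) -> None:
    """Hypothetical helper: xfail when the raised exception matches the
    exception_xfail_reason entry from the reference YAML; otherwise re-raise."""
    reason = reference_data.get("exception_xfail_reason")
    if (
        reason is not None
        and type(exc).__name__ == reason["type"]
        and reason["error_message"] in str(exc)
    ):
        pytest.xfail(reason["message"])  # reported as, e.g., "Issue-<jira ticket number>"
    raise exc
```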
5 changes: 4 additions & 1 deletion tests/post_training/data/ptq_reference_data.yaml
@@ -18,7 +18,10 @@ hf/hf-internal-testing/tiny-random-GPTNeoXForCausalLM_statefull_backend_OPTIMUM:
  metric_value: null
hf/hf-internal-testing/tiny-random-GPTNeoXForCausalLM_stateless_backend_OPTIMUM:
  metric_value: null
  xfail_reason: "Issue-161969"
  exception_xfail_reason:
    type: "TypeError"
    error_message: "cannot pickle 'openvino._pyopenvino.Tensor' object"
    message: "Issue-161969"
hf/hf-internal-testing/tiny-random-gpt2_backend_FP32:
  metric_value: null
hf/hf-internal-testing/tiny-random-gpt2_backend_OPTIMUM:
@@ -9,14 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Dict, List

import nncf
from nncf.experimental.torch.sparsify_activations import TargetScope
from nncf.parameters import CompressWeightsMode
from tests.post_training.experimental.sparsify_activations.pipelines import ImageClassificationTimmSparsifyActivations
from tests.post_training.experimental.sparsify_activations.pipelines import LMSparsifyActivations
from tests.post_training.model_scope import generate_tests_scope
from tests.post_training.pipelines.base import BackendType

SPARSIFY_ACTIVATIONS_MODELS = [
@@ -30,6 +27,7 @@
    {
        "reported_name": "tinyllama_ffn_sparse20",
        "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
        "model_name": "tinyllama",
        "pipeline_cls": LMSparsifyActivations,
        "compression_params": {
            "compress_weights": None,
@@ -45,6 +43,7 @@
    {
        "reported_name": "tinyllama_int8_asym_data_free_ffn_sparse20",
        "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
        "model_name": "tinyllama",
        "pipeline_cls": LMSparsifyActivations,
        "compression_params": {
            "compress_weights": {
@@ -62,6 +61,7 @@
    {
        "reported_name": "timm/deit3_small_patch16_224",
        "model_id": "deit3_small_patch16_224",
        "model_name": "timm/deit3_small_patch16_224",
        "pipeline_cls": ImageClassificationTimmSparsifyActivations,
        "compression_params": {},
        "backends": [BackendType.FP32],
@@ -70,6 +70,7 @@
    {
        "reported_name": "timm/deit3_small_patch16_224_qkv_sparse20_fc1_sparse20_fc2_sparse30",
        "model_id": "deit3_small_patch16_224",
        "model_name": "timm/deit3_small_patch16_224",
        "pipeline_cls": ImageClassificationTimmSparsifyActivations,
        "compression_params": {
            "sparsify_activations": {
@@ -85,34 +86,4 @@
]


def generate_tests_scope(models_list: List[Dict]) -> Dict[str, Dict]:
    """
    Generate tests by names "{reported_name}_backend_{backend}"
    """
    tests_scope = {}
    fp32_models = set()
    for test_model_param in models_list:
        model_id = test_model_param["model_id"]
        reported_name = test_model_param["reported_name"]

        for backend in test_model_param["backends"]:
            model_param = copy.deepcopy(test_model_param)
            if "is_batch_size_supported" not in model_param:  # Set default value of is_batch_size_supported.
                model_param["is_batch_size_supported"] = True
            test_case_name = f"{reported_name}_backend_{backend.value}"
            model_param["backend"] = backend
            model_param.pop("backends")
            if backend == BackendType.FP32:
                if model_id in fp32_models:
                    msg = f"Duplicate test case for {model_id} with FP32 backend"
                    raise nncf.ValidationError(msg)
                fp32_models.add(model_id)
            if test_case_name in tests_scope:
                msg = f"{test_case_name} already in tests_scope"
                raise nncf.ValidationError(msg)
            tests_scope[test_case_name] = model_param

    return tests_scope


SPARSIFY_ACTIVATIONS_TEST_CASES = generate_tests_scope(SPARSIFY_ACTIVATIONS_MODELS)
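
For context, a minimal sketch of what the relocated `generate_tests_scope` produces for the entries above, assuming it behaves like the local copy removed in this diff; the assertions are illustrative:

```python
# Illustrative only: test-case keys follow "{reported_name}_backend_{backend.value}".
cases = generate_tests_scope(SPARSIFY_ACTIVATIONS_MODELS)

case = cases["timm/deit3_small_patch16_224_backend_FP32"]
assert case["backend"] == BackendType.FP32  # the "backends" list is replaced by a single backend
assert case["is_batch_size_supported"] is True  # default filled in when the key is absent
assert "backends" not in case
```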
125 changes: 8 additions & 117 deletions tests/post_training/experimental/sparsify_activations/pipelines.py
@@ -11,8 +11,6 @@


from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
@@ -32,13 +30,9 @@
from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend
from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor
from tests.post_training.pipelines.base import LIMIT_LENGTH_OF_STATUS
from tests.post_training.pipelines.base import PT_BACKENDS
from tests.post_training.pipelines.base import BackendType
from tests.post_training.pipelines.base import ErrorReason
from tests.post_training.pipelines.base import ErrorReport
from tests.post_training.pipelines.base import NumCompressNodes
from tests.post_training.pipelines.base import RunInfo
from tests.post_training.pipelines.base import PTQTestPipeline
from tests.post_training.pipelines.image_classification_timm import ImageClassificationTimm
from tests.post_training.pipelines.lm_weight_compression import LMWeightCompression
from tests.post_training.pipelines.lm_weight_compression import WCTimeStats
@@ -58,94 +52,11 @@ class SATimeStats(WCTimeStats):
    REGEX_PREFIX = [*WCTimeStats.REGEX_PREFIX, SparsifyActivationsAlgoBackend.CALIBRATION_TRACKING_DESC]


@dataclass
class SANumCompressNodes(NumCompressNodes):
    num_sparse_activations: Optional[int] = None


@dataclass
class SARunInfo(RunInfo):
    num_compress_nodes: SANumCompressNodes = field(default_factory=SANumCompressNodes)

    def get_result_dict(self):
        return {
            "Model": self.model,
            "Backend": self.backend.value if self.backend else None,
            "Metric name": self.metric_name,
            "Metric value": self.metric_value,
            "Metric diff": self.metric_diff,
            "Num FQ": self.num_compress_nodes.num_fq_nodes,
            "Num int4": self.num_compress_nodes.num_int4,
            "Num int8": self.num_compress_nodes.num_int8,
            "Num sparse activations": self.num_compress_nodes.num_sparse_activations,
            "RAM MiB": self.format_memory_usage(self.compression_memory_usage),
            "Compr. time": self.format_time(self.time_compression),
            **self.stats_from_output.get_stats(),
            "Total time": self.format_time(self.time_total),
            "FPS": self.fps,
            "Status": self.status[:LIMIT_LENGTH_OF_STATUS] if self.status is not None else None,
        }


class SAPipelineMixin:
class SAPipelineMixin(PTQTestPipeline):
"""
Common methods in the test pipeline for Sparsify Activations.
"""

    def __init__(
        self,
        reported_name: str,
        model_id: str,
        backend: BackendType,
        compression_params: dict,
        output_dir: Path,
        data_dir: Path,
        reference_data: dict,
        no_eval: bool,
        run_benchmark_app: bool,
        params: dict = None,
        batch_size: int = 1,
    ):
        super().__init__(
            reported_name=reported_name,
            model_id=model_id,
            backend=backend,
            compression_params=compression_params,
            output_dir=output_dir,
            data_dir=data_dir,
            reference_data=reference_data,
            no_eval=no_eval,
            run_benchmark_app=run_benchmark_app,
            params=params,
            batch_size=batch_size,
        )
        self.run_info = SARunInfo(model=reported_name, backend=backend)

    @staticmethod
    def count_compressed_nodes_from_ir(model: ov.Model) -> SANumCompressNodes:
        """
        Get number of compressed nodes in the compressed IR.
        """
        num_fq_nodes = 0
        num_int8 = 0
        num_int4 = 0
        for node in model.get_ops():
            if node.type_info.name == "FakeQuantize":
                num_fq_nodes += 1
            for i in range(node.get_output_size()):
                if node.get_output_element_type(i).get_type_name() in ["i8", "u8"]:
                    num_int8 += 1
                if node.get_output_element_type(i).get_type_name() in ["i4", "u4", "nf4"]:
                    num_int4 += 1

        num_sparse_activations = count_sparsifier_patterns_in_ov(model)
        return SANumCompressNodes(
            num_fq_nodes=num_fq_nodes,
            num_int8=num_int8,
            num_int4=num_int4,
            num_sparse_activations=num_sparse_activations,
        )

    def collect_data_from_stdout(self, stdout: str):
        stats = SATimeStats()
        stats.fill(stdout)
@@ -171,15 +82,11 @@ def _compress(self):
                **self.compression_params["sparsify_activations"],
            )

    def _validate(self):
        errors = super()._validate()
        ref_num_sparse_activations = self.reference_data.get("num_sparse_activations", 0)
        num_sparse_activations = self.run_info.num_compress_nodes.num_sparse_activations
        if num_sparse_activations != ref_num_sparse_activations:
            status_msg = f"Regression: The number of sparse activations is {num_sparse_activations}, \
which differs from reference {ref_num_sparse_activations}."
            errors.append(ErrorReport(ErrorReason.NUM_COMPRESSED, status_msg))
        return errors
    def get_num_compressed(self) -> None:
        super().get_num_compressed()
        ie = ov.Core()
        model = ie.read_model(model=self.path_compressed_ir)
        self.run_info.num_compress_nodes.num_sparse_activations = count_sparsifier_patterns_in_ov(model)


class LMSparsifyActivations(SAPipelineMixin, LMWeightCompression):
@@ -268,6 +175,7 @@ def prepare_calibration_dataset(self):
        self.calibration_dataset = nncf.Dataset(chunks, self.get_transform_calibration_fn())

    def save_compressed_model(self):
        self.path_compressed_ir = self.output_model_dir / self.OV_MODEL_NAME
        if self.backend == BackendType.CUDA_TORCH:
            self.model_hf.float()
            for module in self.model_hf.nncf.modules():
@@ -279,16 +187,6 @@ def save_compressed_model(self):
        else:
            super().save_compressed_model()

    def get_num_compressed(self):
        """
        Get number of quantization ops and sparsifier ops in the compressed IR.
        """
        if self.backend in PT_BACKENDS:
            model = ov.Core().read_model(self.output_model_dir / self.OV_MODEL_NAME)
        else:
            model = self.model
        self.run_info.num_compress_nodes = self.count_compressed_nodes_from_ir(model)

    def _dump_model_fp32(self):
        if self.backend == BackendType.CUDA_TORCH:
            export_from_model(
@@ -318,10 +216,3 @@ def prepare_calibration_dataset(self):
        subset = torch.utils.data.Subset(val_dataset, indices=indices)
        loader = torch.utils.data.DataLoader(subset, batch_size=self.batch_size, num_workers=2, shuffle=False)
        self.calibration_dataset = nncf.Dataset(loader, self.get_transform_calibration_fn())

    def get_num_compressed(self):
        """
        Get number of quantization ops and sparsifier ops in the compressed IR.
        """
        model = ov.Core().read_model(model=self.path_compressed_ir)
        self.run_info.num_compress_nodes = self.count_compressed_nodes_from_ir(model)
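
As a closing note, a standalone sketch of the IR node counting that the removed `count_compressed_nodes_from_ir` performed; with this PR that bookkeeping presumably moves behind the shared `PTQTestPipeline.get_num_compressed`, which `SAPipelineMixin` extends above with the sparse-activation count. The model path below is illustrative:

```python
import openvino as ov


def count_quantize_ops(model: ov.Model) -> dict:
    """Count FakeQuantize nodes and low-precision outputs, mirroring the removed helper."""
    num_fq_nodes = num_int8 = num_int4 = 0
    for node in model.get_ops():
        if node.type_info.name == "FakeQuantize":
            num_fq_nodes += 1
        for i in range(node.get_output_size()):
            type_name = node.get_output_element_type(i).get_type_name()
            if type_name in ("i8", "u8"):
                num_int8 += 1
            if type_name in ("i4", "u4", "nf4"):
                num_int4 += 1
    return {"num_fq_nodes": num_fq_nodes, "num_int8": num_int8, "num_int4": num_int4}


print(count_quantize_ops(ov.Core().read_model("compressed_model.xml")))  # path is illustrative
```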