 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from enum import Enum
-from typing import Any, Dict, List, Tuple
+import random
+from pathlib import Path
+from time import time
+from typing import Dict, List, Tuple
 
 import numpy as np
 import openvino as ov
+import pandas as pd
 import pytest
+from tqdm import tqdm
 
 from openvino_xai import Task
+from openvino_xai.common.parameters import (
+    BlackBoxXAIMethods,
+    Method,
+    Task,
+    WhiteBoxXAIMethods,
+)
 from openvino_xai.common.utils import retrieve_otx_model
 from openvino_xai.explainer.explainer import Explainer, ExplainMode
+from openvino_xai.explainer.explanation import Explanation
 from openvino_xai.explainer.utils import (
     ActivationType,
     get_postprocess_fn,
     get_preprocess_fn,
 )
+from openvino_xai.methods.black_box.base import Preset
 from openvino_xai.metrics import ADCC, InsertionDeletionAUC, PointingGame
-from tests.unit.explanation.test_explanation_utils import VOC_NAMES
+from tests.perf.perf_tests_utils import convert_timm_to_ir
+from tests.test_suite.custom_dataset import CustomVOCDetection
+from tests.test_suite.dataset_utils import (
+    DatasetType,
+    coco_anns_to_gt_bboxes,
+    define_dataset_type,
+    voc_anns_to_gt_bboxes,
+)
+from tests.unit.explainer.test_explanation_utils import VOC_NAMES, get_imagenet_labels
 
 datasets = pytest.importorskip("torchvision.datasets")
+timm = pytest.importorskip("timm")
+torch = pytest.importorskip("torch")
+
+
+IMAGENET_MODELS = [
+    "resnet18.a1_in1k",
+    # "resnet50.a1_in1k",
+    # "resnext50_32x4d.a1h_in1k",
+    # "vgg16.tv_in1k"
+]
+VOC_MODELS = [
+    # "mlc_mobilenetv3_large_voc"
+]
+TRANSFORMER_MODELS = [
+    "deit_tiny_patch16_224.fb_in1k",  # Downloads last month: 8,377
+    # "deit_base_patch16_224.fb_in1k",  # Downloads last month: 6,323
+    # "vit_tiny_patch16_224.augreg_in21k",  # Downloads last month: 3,671 - trained on ImageNet-21k
+    # "vit_base_patch16_224.augreg2_in21k_ft_in1k",  # Downloads last month: 207,590 - trained on ImageNet-21k
+]
+
+TEST_MODELS = IMAGENET_MODELS + VOC_MODELS + TRANSFORMER_MODELS
+IMAGENET_LABELS = get_imagenet_labels()
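+# Covers both white-box (ReciproCAM, ActivationMap) and black-box (AISE, RISE) methods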
+EXPLAIN_METHODS = [Method.RECIPROCAM, Method.AISE, Method.RISE, Method.ACTIVATIONMAP]
 
 
-class DatasetType(Enum):
-    COCO = "coco"
-    VOC = "voc"
-
-
-def coco_anns_to_gt_bboxes(
-    anns: List[Dict[str, Any]] | Dict[str, Any], coco_val_labels: Dict[int, str]
-) -> Dict[str, List[Tuple[int, int, int, int]]]:
-    gt_bboxes = {}
-    for ann in anns:
-        category_id = ann["category_id"]
-        category_name = coco_val_labels[category_id]
-        bbox = ann["bbox"]
-        if category_name not in gt_bboxes:
-            gt_bboxes[category_name] = []
-        gt_bboxes[category_name].append(bbox)
-    return gt_bboxes
-
-
-def voc_anns_to_gt_bboxes(
-    anns: List[Dict[str, Any]] | Dict[str, Any], *args: Any
-) -> Dict[str, List[Tuple[int, int, int, int]]]:
-    gt_bboxes = {}
-    anns = anns["annotation"]["object"]
-    for ann in anns:
-        category_name = ann["name"]
-        bndbox = list(map(float, ann["bndbox"].values()))
-        bndbox = np.array(bndbox, dtype=np.int32)
-        x_min, y_min, x_max, y_max = bndbox
-        bbox = (x_min, y_min, x_max - x_min, y_max - y_min)
-
-        if category_name not in gt_bboxes:
-            gt_bboxes[category_name] = []
-        gt_bboxes[category_name].append(bbox)
-    return gt_bboxes
-
-
-def define_dataset_type(data_root: str, ann_path: str) -> DatasetType:
-    if data_root and ann_path and ann_path.lower().endswith(".json"):
-        if any(image_name.endswith(".jpg") for image_name in os.listdir(data_root)):
-            return DatasetType.COCO
-
-    required_voc_dirs = {"JPEGImages", "SegmentationObject", "ImageSets", "Annotations", "SegmentationClass"}
-    for _, dir, _ in os.walk(data_root):
-        if required_voc_dirs.issubset(set(dir)):
-            return DatasetType.VOC
-
-    raise ValueError("Dataset type is not supported")
-
-
-@pytest.mark.parametrize(
-    "data_root, ann_path",
-    [
-        ("tests/assets/cheetah_coco/images/val", "tests/assets/cheetah_coco/annotations/instances_val.json"),
-        ("tests/assets/cheetah_voc", None),
-    ],
-)
 class TestAccuracy:
-    MODEL_NAME = "mlc_mobilenetv3_large_voc"
-
-    @pytest.fixture(autouse=True)
-    def setup(self, fxt_data_root, data_root, ann_path):
-        data_dir = fxt_data_root
-        retrieve_otx_model(data_dir, self.MODEL_NAME)
-        model_path = data_dir / "otx_models" / (self.MODEL_NAME + ".xml")
-        model = ov.Core().read_model(model_path)
-
-        self.setup_dataset(data_root, ann_path)
-
-        self.preprocess_fn = get_preprocess_fn(
-            change_channel_order=self.channel_format == "BGR",
-            input_size=(224, 224),
-            hwc_to_chw=True,
-        )
-        self.postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
-
-        self.explainer = Explainer(
-            model=model,
-            task=Task.CLASSIFICATION,
-            preprocess_fn=self.preprocess_fn,
-            explain_mode=ExplainMode.WHITEBOX,
-        )
-
-        self.pointing_game = PointingGame()
-        self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)
-        self.adcc = ADCC(model, self.preprocess_fn, self.postprocess_fn, self.explainer)
+    def setup_dataset(self, dataset_parameters: Tuple[Path | None, Path | None]):
+        if dataset_parameters == (None, None):
+            data_root, ann_path = Path("tests/assets/cheetah_voc"), None
+        else:
+            data_root, ann_path = dataset_parameters
 
-    def setup_dataset(self, data_root: str, ann_path: str):
         self.dataset_type = define_dataset_type(data_root, ann_path)
-        self.channel_format = "RGB" if self.dataset_type in [DatasetType.VOC, DatasetType.COCO] else "None"
-
         if self.dataset_type == DatasetType.COCO:
             self.dataset = datasets.CocoDetection(root=data_root, annFile=ann_path)
             self.dataset_labels_dict = {cats["id"]: cats["name"] for cats in self.dataset.coco.cats.values()}
             self.anns_to_gt_bboxes = coco_anns_to_gt_bboxes
-        elif self.dataset_type == DatasetType.VOC:
-            self.dataset = datasets.VOCDetection(root=data_root, download=False, year="2012", image_set="val")
+        elif self.dataset_type in [DatasetType.VOC, DatasetType.ILSVRC]:
+            self.dataset = CustomVOCDetection(root=data_root, download=False, year="2012", image_set="val")
             self.dataset_labels_dict = None
             self.anns_to_gt_bboxes = voc_anns_to_gt_bboxes
+        self.dataset = self.subset_dataset(num_samples=5000, seed=42)
+
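+    # Deterministically subsample large datasets (fixed seed) to keep runtime bounded while staying reproducible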
+    def subset_dataset(self, num_samples=-1, seed=42):
+        if num_samples == -1 or num_samples >= len(self.dataset):
+            return self.dataset
+        random.seed(seed)
+        subset_indices = random.sample(range(len(self.dataset)), num_samples)
+        return torch.utils.data.Subset(self.dataset, subset_indices)
+
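+    # VOC models are fetched from the OTX model zoo; ImageNet/transformer models are timm checkpoints converted to OpenVINO IR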
+    def setup_model(self, data_dir, model_name):
+        if model_name in VOC_MODELS:
+            self.dataset_label_list = VOC_NAMES
+            retrieve_otx_model(data_dir, model_name)
+            model_path = data_dir / "otx_models" / (model_name + ".xml")
+            model = ov.Core().read_model(model_path)
+            return model, None
+
+        elif model_name in IMAGENET_MODELS + TRANSFORMER_MODELS:
+            self.dataset_label_list = IMAGENET_LABELS
+            _, model_cfg = convert_timm_to_ir(model_name, data_dir, self.supported_num_classes)
+            ir_path = data_dir / "timm_models" / "converted_models" / model_name / "model_fp32.xml"
+            model = ov.Core().read_model(ir_path)
+            return model, model_cfg
+        else:
+            raise ValueError(f"Model {model_name} is not supported since it is neither a VOC nor an ImageNet model.")
+
+    def setup_process_fn(self, model_cfg):
+        if self.model_name in VOC_MODELS:
+            # VOC model
+            self.preprocess_fn = get_preprocess_fn(
+                change_channel_order=False,
+                input_size=(224, 224),
+                hwc_to_chw=True,
+            )
+            self.postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
+        elif self.model_name in IMAGENET_MODELS + TRANSFORMER_MODELS:
+            # Timm ImageNet model
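+            # timm configs store mean/std normalized to [0, 1]; scale to [0, 255] to match uint8 image inputs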
+            mean_values = [(item * 255) for item in model_cfg["mean"]]
+            scale_values = [(item * 255) for item in model_cfg["std"]]
+            self.preprocess_fn = get_preprocess_fn(
+                change_channel_order=True,
+                input_size=model_cfg["input_size"][1:],
+                mean=mean_values,
+                std=scale_values,
+                hwc_to_chw=True,
+            )
+            self.postprocess_fn = get_postprocess_fn(activation=ActivationType.SOFTMAX)
+        else:
+            raise ValueError(f"Model {self.model_name} is not supported since it is neither a VOC nor an ImageNet model.")
+
+    def setup_explainer(self, model, explain_method):
+        explain_mode = ExplainMode.WHITEBOX if explain_method in WhiteBoxXAIMethods else ExplainMode.BLACKBOX
+
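+        # ReciproCAM has a ViT-specific variant; swap it in for transformer backbones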
+        if self.model_name in TRANSFORMER_MODELS and explain_method == Method.RECIPROCAM:
+            explain_method = Method.VITRECIPROCAM
 
-    def test_explainer_images(self):
-        images, explanations, dataset_gt_bboxes = [], [], []
-        for image, anns in self.dataset:
-            image_np = np.array(image)
-            gt_bbox_dict = self.anns_to_gt_bboxes(anns, self.dataset_labels_dict)
-            targets = [target for target in gt_bbox_dict.keys() if target in VOC_NAMES]
+        self.explainer = Explainer(
+            model=model,
+            task=Task.CLASSIFICATION,
+            preprocess_fn=self.preprocess_fn,
+            postprocess_fn=self.postprocess_fn,
+            explain_mode=explain_mode,
+            explain_method=explain_method,
+            embed_scaling=True,
+        )
+        kwargs = {}
+        if explain_method in BlackBoxXAIMethods:
+            # TODO: Make Preset configurable as well
+            kwargs.update({"preset": Preset.SPEED})
+        return kwargs
+
+    @pytest.fixture(autouse=True)
+    def setup(self, fxt_data_root, fxt_output_root, fxt_dataset_parameters):
+        self.data_dir = fxt_data_root
+        self.output_dir = fxt_output_root
+        self.supported_num_classes = {1000: 1000}
 
-            explanation = self.explainer(image_np, targets=targets, label_names=VOC_NAMES, colormap=False)
+        self.setup_dataset(fxt_dataset_parameters)
+        self.dataset_name = self.dataset_type.value
 
-            images.append(image_np)
-            explanations.append(explanation)
-            dataset_gt_bboxes.append({key: value for key, value in gt_bbox_dict.items() if key in targets})
+    @pytest.mark.parametrize("model_id", TEST_MODELS)
+    @pytest.mark.parametrize("explain_method", EXPLAIN_METHODS)
+    def test_explainer_images(self, model_id, explain_method):
+        self.model_name = model_id
+        self.data_metric_path = self.output_dir / self.model_name / explain_method.value
+        os.makedirs(self.data_metric_path, exist_ok=True)
 
-        pointing_game = self.pointing_game.evaluate(explanations, dataset_gt_bboxes)
-        auc = self.auc.evaluate(explanations, images, steps=10)
-        adcc = self.adcc.evaluate(explanations, images)
+        model, model_cfg = self.setup_model(self.data_dir, self.model_name)
+        self.setup_process_fn(model_cfg)
+        black_box_kwargs = self.setup_explainer(model, explain_method)
 
-        return {**pointing_game, **auc, **adcc}
+        self.pointing_game = PointingGame()
+        self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)
+        self.adcc = ADCC(model, self.preprocess_fn, self.postprocess_fn, self.explainer, **black_box_kwargs)
+
+        records = []
+        explained_images = 0
+        experiment_start_time = time()
+        batch_size = 1000
+
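+        # Process the dataset in chunks of batch_size so per-chunk metrics are flushed to CSV incrementally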
+        for lrange in tqdm(range(0, len(self.dataset), batch_size), desc="Processing batches"):
+            rrange = min(len(self.dataset), lrange + batch_size)
+
+            start_time = time()
+            images, explanations, dataset_gt_bboxes = [], [], []
+            for i in range(lrange, rrange):
+                image, anns = self.dataset[i]
+                image_np = np.array(image)  # PIL -> np.array
+                gt_bbox_dict = self.anns_to_gt_bboxes(anns, self.dataset_labels_dict)
+
+                # To measure saliency-map quality without GT info from the dataset (TODO: find out how to check this)
+                # targets = np.argmax(self.model_predict(image_np))
+                targets = list(gt_bbox_dict.keys())
+                intersected_targets = list(set(targets) & set(self.dataset_label_list))
+                if len(intersected_targets) == 0:
+                    # Skip images where gt classes and model classes do not match
+                    continue
+                explanation = self.explainer(
+                    image_np,
+                    targets=intersected_targets,
+                    label_names=self.dataset_label_list,
+                    colormap=False,
+                    **black_box_kwargs,
+                )
+                images.append(image_np)
+                explanations.append(explanation)
+                dataset_gt_bboxes.append(gt_bbox_dict)
+
+            # Write per-batch statistics to track failures
+            explained_images += len(explanations)
+            record = {"range": f"{lrange}-{rrange}"}
+            record.update(self.get_xai_metrics(explanations, images, dataset_gt_bboxes, start_time))
+            records.append(record)
+
+            df = pd.DataFrame([record]).round(3)
+            df.to_csv(self.data_metric_path / f"accuracy_{self.dataset_name}.csv", mode="a", header=False, index=False)
+
+        experiment_time = time() - experiment_start_time
+        mean_scores_dict = {"explained_images": explained_images, "overall_time": experiment_time}
+        mean_scores_dict.update(
+            {
+                key: np.mean([record[key] for record in records if key in record])
+                for key in records[0].keys()
+                if key != "range"
+            }
+        )
+        df = pd.DataFrame([mean_scores_dict]).round(3)
+        df.to_csv(self.data_metric_path / f"mean_accuracy_{self.dataset_name}.csv", index=False)
+
+    def get_xai_metrics(
+        self,
+        explanations: List[Explanation],
+        images: List[np.ndarray],
+        dataset_gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]],
+        start_time: float,
+    ):
+        score = {}
+        if len(explanations) == 0:
+            return score
+
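+        # Nested helper: run one metric and record its wall-clock time alongside its scores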
+        def evaluate_metric_time(metric_name, evaluation_func, *args, **kwargs):
+            previous_time = time()
+            score.update(evaluation_func(*args, **kwargs))
+            score[f"{metric_name}_time"] = time() - previous_time
+
+        score["explain_time"] = time() - start_time
+        evaluate_metric_time("pointing_game", self.pointing_game.evaluate, explanations, dataset_gt_bboxes)
+        evaluate_metric_time("auc", self.auc.evaluate, explanations, images, steps=30)
+        evaluate_metric_time("adcc", self.adcc.evaluate, explanations, images)
+
+        return score