
Commit b28d731

Update documentation
1 parent 3f59b7f commit b28d731

1 file changed, +219 -107 lines

releases/1.1.0/_downloads/959b21439550740fb32540e68430cfbe/test_accuracy.py (+219 -107)
@@ -2,142 +2,254 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
-from enum import Enum
-from typing import Any, Dict, List, Tuple
+import random
+from pathlib import Path
+from time import time
+from typing import Dict, List, Tuple

 import numpy as np
 import openvino as ov
+import pandas as pd
 import pytest
+from tqdm import tqdm

 from openvino_xai import Task
+from openvino_xai.common.parameters import (
+    BlackBoxXAIMethods,
+    Method,
+    Task,
+    WhiteBoxXAIMethods,
+)
 from openvino_xai.common.utils import retrieve_otx_model
 from openvino_xai.explainer.explainer import Explainer, ExplainMode
+from openvino_xai.explainer.explanation import Explanation
 from openvino_xai.explainer.utils import (
     ActivationType,
     get_postprocess_fn,
     get_preprocess_fn,
 )
+from openvino_xai.methods.black_box.base import Preset
 from openvino_xai.metrics import ADCC, InsertionDeletionAUC, PointingGame
-from tests.unit.explanation.test_explanation_utils import VOC_NAMES
+from tests.perf.perf_tests_utils import convert_timm_to_ir
+from tests.test_suite.custom_dataset import CustomVOCDetection
+from tests.test_suite.dataset_utils import (
+    DatasetType,
+    coco_anns_to_gt_bboxes,
+    define_dataset_type,
+    voc_anns_to_gt_bboxes,
+)
+from tests.unit.explainer.test_explanation_utils import VOC_NAMES, get_imagenet_labels

 datasets = pytest.importorskip("torchvision.datasets")
+timm = pytest.importorskip("timm")
+torch = pytest.importorskip("torch")
+
+
+IMAGENET_MODELS = [
+    "resnet18.a1_in1k",
+    # "resnet50.a1_in1k",
+    # "resnext50_32x4d.a1h_in1k",
+    # "vgg16.tv_in1k"
+]
+VOC_MODELS = [
+    # "mlc_mobilenetv3_large_voc"
+]
+TRANSFORMER_MODELS = [
+    "deit_tiny_patch16_224.fb_in1k",  # Downloads last month 8,377
+    # "deit_base_patch16_224.fb_in1k",  # Downloads last month 6,323
+    # "vit_tiny_patch16_224.augreg_in21k",  # Downloads last month 3,671 - trained on ImageNet-21k
+    # "vit_base_patch16_224.augreg2_in21k_ft_in1k",  # Downloads last month 207,590 - trained on ImageNet-21k
+]
+
+TEST_MODELS = IMAGENET_MODELS + VOC_MODELS + TRANSFORMER_MODELS
+IMAGENET_LABELS = get_imagenet_labels()
+EXPLAIN_METHODS = [Method.RECIPROCAM, Method.AISE, Method.RISE, Method.ACTIVATIONMAP]


-class DatasetType(Enum):
-    COCO = "coco"
-    VOC = "voc"
-
-
-def coco_anns_to_gt_bboxes(
-    anns: List[Dict[str, Any]] | Dict[str, Any], coco_val_labels: Dict[int, str]
-) -> Dict[str, List[Tuple[int, int, int, int]]]:
-    gt_bboxes = {}
-    for ann in anns:
-        category_id = ann["category_id"]
-        category_name = coco_val_labels[category_id]
-        bbox = ann["bbox"]
-        if category_name not in gt_bboxes:
-            gt_bboxes[category_name] = []
-        gt_bboxes[category_name].append(bbox)
-    return gt_bboxes
-
-
-def voc_anns_to_gt_bboxes(
-    anns: List[Dict[str, Any]] | Dict[str, Any], *args: Any
-) -> Dict[str, List[Tuple[int, int, int, int]]]:
-    gt_bboxes = {}
-    anns = anns["annotation"]["object"]
-    for ann in anns:
-        category_name = ann["name"]
-        bndbox = list(map(float, ann["bndbox"].values()))
-        bndbox = np.array(bndbox, dtype=np.int32)
-        x_min, y_min, x_max, y_max = bndbox
-        bbox = (x_min, y_min, x_max - x_min, y_max - y_min)
-
-        if category_name not in gt_bboxes:
-            gt_bboxes[category_name] = []
-        gt_bboxes[category_name].append(bbox)
-    return gt_bboxes
-
-
-def define_dataset_type(data_root: str, ann_path: str) -> DatasetType:
-    if data_root and ann_path and ann_path.lower().endswith(".json"):
-        if any(image_name.endswith(".jpg") for image_name in os.listdir(data_root)):
-            return DatasetType.COCO
-
-    required_voc_dirs = {"JPEGImages", "SegmentationObject", "ImageSets", "Annotations", "SegmentationClass"}
-    for _, dir, _ in os.walk(data_root):
-        if required_voc_dirs.issubset(set(dir)):
-            return DatasetType.VOC
-
-    raise ValueError("Dataset type is not supported")
-
-
-@pytest.mark.parametrize(
-    "data_root, ann_path",
-    [
-        ("tests/assets/cheetah_coco/images/val", "tests/assets/cheetah_coco/annotations/instances_val.json"),
-        ("tests/assets/cheetah_voc", None),
-    ],
-)
 class TestAccuracy:
-    MODEL_NAME = "mlc_mobilenetv3_large_voc"
-
-    @pytest.fixture(autouse=True)
-    def setup(self, fxt_data_root, data_root, ann_path):
-        data_dir = fxt_data_root
-        retrieve_otx_model(data_dir, self.MODEL_NAME)
-        model_path = data_dir / "otx_models" / (self.MODEL_NAME + ".xml")
-        model = ov.Core().read_model(model_path)
-
-        self.setup_dataset(data_root, ann_path)
-
-        self.preprocess_fn = get_preprocess_fn(
-            change_channel_order=self.channel_format == "BGR",
-            input_size=(224, 224),
-            hwc_to_chw=True,
-        )
-        self.postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
-
-        self.explainer = Explainer(
-            model=model,
-            task=Task.CLASSIFICATION,
-            preprocess_fn=self.preprocess_fn,
-            explain_mode=ExplainMode.WHITEBOX,
-        )
-
-        self.pointing_game = PointingGame()
-        self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)
-        self.adcc = ADCC(model, self.preprocess_fn, self.postprocess_fn, self.explainer)
+    def setup_dataset(self, dataset_parameters: Tuple[Path | None, Path | None]):
+        if dataset_parameters == (None, None):
+            data_root, ann_path = Path("tests/assets/cheetah_voc"), None
+        else:
+            data_root, ann_path = dataset_parameters

-    def setup_dataset(self, data_root: str, ann_path: str):
         self.dataset_type = define_dataset_type(data_root, ann_path)
-        self.channel_format = "RGB" if self.dataset_type in [DatasetType.VOC, DatasetType.COCO] else "None"
-
         if self.dataset_type == DatasetType.COCO:
             self.dataset = datasets.CocoDetection(root=data_root, annFile=ann_path)
             self.dataset_labels_dict = {cats["id"]: cats["name"] for cats in self.dataset.coco.cats.values()}
             self.anns_to_gt_bboxes = coco_anns_to_gt_bboxes
-        elif self.dataset_type == DatasetType.VOC:
-            self.dataset = datasets.VOCDetection(root=data_root, download=False, year="2012", image_set="val")
+        elif self.dataset_type in [DatasetType.VOC, DatasetType.ILSVRC]:
+            self.dataset = CustomVOCDetection(root=data_root, download=False, year="2012", image_set="val")
             self.dataset_labels_dict = None
             self.anns_to_gt_bboxes = voc_anns_to_gt_bboxes
+        self.dataset = self.subset_dataset(num_samples=5000, seed=42)
+
+    def subset_dataset(self, num_samples=-1, seed=42):
+        if num_samples == -1 or num_samples >= len(self.dataset):
+            return self.dataset
+        random.seed(seed)
+        subset_indices = random.sample(range(len(self.dataset)), num_samples)
+        return torch.utils.data.Subset(self.dataset, subset_indices)
+
+    def setup_model(self, data_dir, model_name):
+        if model_name in VOC_MODELS:
+            self.dataset_label_list = VOC_NAMES
+            retrieve_otx_model(data_dir, model_name)
+            model_path = data_dir / "otx_models" / (model_name + ".xml")
+            model = ov.Core().read_model(model_path)
+            return model, None
+
+        elif model_name in IMAGENET_MODELS + TRANSFORMER_MODELS:
+            self.dataset_label_list = IMAGENET_LABELS
+            _, model_cfg = convert_timm_to_ir(model_name, data_dir, self.supported_num_classes)
+            ir_path = data_dir / "timm_models" / "converted_models" / model_name / "model_fp32.xml"
+            model = ov.Core().read_model(ir_path)
+            return model, model_cfg
+        else:
+            raise ValueError(f"Model {model_name} is not supported since it is neither a VOC nor an ImageNet model.")
+
+    def setup_process_fn(self, model_cfg):
+        if self.model_name in VOC_MODELS:
+            # VOC model
+            self.preprocess_fn = get_preprocess_fn(
+                change_channel_order=False,
+                input_size=(224, 224),
+                hwc_to_chw=True,
+            )
+            self.postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
+        elif self.model_name in IMAGENET_MODELS + TRANSFORMER_MODELS:
+            # Timm ImageNet model
+            mean_values = [(item * 255) for item in model_cfg["mean"]]
+            scale_values = [(item * 255) for item in model_cfg["std"]]
+            self.preprocess_fn = get_preprocess_fn(
+                change_channel_order=True,
+                input_size=model_cfg["input_size"][1:],
+                mean=mean_values,
+                std=scale_values,
+                hwc_to_chw=True,
+            )
+            self.postprocess_fn = get_postprocess_fn(activation=ActivationType.SOFTMAX)
+        else:
+            raise ValueError(f"Model {self.model_name} is not supported since it is neither a VOC nor an ImageNet model.")
+
+    def setup_explainer(self, model, explain_method):
+        explain_mode = ExplainMode.WHITEBOX if explain_method in WhiteBoxXAIMethods else ExplainMode.BLACKBOX
+
+        if self.model_name in TRANSFORMER_MODELS and explain_method == Method.RECIPROCAM:
+            explain_method = Method.VITRECIPROCAM

-    def test_explainer_images(self):
-        images, explanations, dataset_gt_bboxes = [], [], []
-        for image, anns in self.dataset:
-            image_np = np.array(image)
-            gt_bbox_dict = self.anns_to_gt_bboxes(anns, self.dataset_labels_dict)
-            targets = [target for target in gt_bbox_dict.keys() if target in VOC_NAMES]
+        self.explainer = Explainer(
+            model=model,
+            task=Task.CLASSIFICATION,
+            preprocess_fn=self.preprocess_fn,
+            postprocess_fn=self.postprocess_fn,
+            explain_mode=explain_mode,
+            explain_method=explain_method,
+            embed_scaling=True,
+        )
+        kwargs = {}
+        if explain_method in BlackBoxXAIMethods:
+            # TODO: Make Preset configurable as well
+            kwargs.update({"preset": Preset.SPEED})
+        return kwargs
+
+    @pytest.fixture(autouse=True)
+    def setup(self, fxt_data_root, fxt_output_root, fxt_dataset_parameters):
+        self.data_dir = fxt_data_root
+        self.output_dir = fxt_output_root
+        self.supported_num_classes = {1000: 1000}

-        explanation = self.explainer(image_np, targets=targets, label_names=VOC_NAMES, colormap=False)
+        self.setup_dataset(fxt_dataset_parameters)
+        self.dataset_name = self.dataset_type.value

-            images.append(image_np)
-            explanations.append(explanation)
-            dataset_gt_bboxes.append({key: value for key, value in gt_bbox_dict.items() if key in targets})
+    @pytest.mark.parametrize("model_id", TEST_MODELS)
+    @pytest.mark.parametrize("explain_method", EXPLAIN_METHODS)
+    def test_explainer_images(self, model_id, explain_method):
+        self.model_name = model_id
+        self.data_metric_path = self.output_dir / self.model_name / explain_method.value
+        os.makedirs(self.data_metric_path, exist_ok=True)

-        pointing_game = self.pointing_game.evaluate(explanations, dataset_gt_bboxes)
-        auc = self.auc.evaluate(explanations, images, steps=10)
-        adcc = self.adcc.evaluate(explanations, images)
+        model, model_cfg = self.setup_model(self.data_dir, self.model_name)
+        self.setup_process_fn(model_cfg)
+        black_box_kwargs = self.setup_explainer(model, explain_method)

-        return {**pointing_game, **auc, **adcc}
+        self.pointing_game = PointingGame()
+        self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)
+        self.adcc = ADCC(model, self.preprocess_fn, self.postprocess_fn, self.explainer, **black_box_kwargs)
+
+        records = []
+        explained_images = 0
+        experiment_start_time = time()
+        batch_size = 1000
+
+        for lrange in tqdm(range(0, len(self.dataset), batch_size), desc="Processing batches"):
+            rrange = min(len(self.dataset), lrange + batch_size)
+
+            start_time = time()
+            images, explanations, dataset_gt_bboxes = [], [], []
+            for i in range(lrange, rrange):
+                image, anns = self.dataset[i]
+                image_np = np.array(image)  # PIL -> np.array
+                gt_bbox_dict = self.anns_to_gt_bboxes(anns, self.dataset_labels_dict)
+
+                # To measure the quality of predicted saliency maps without the GT info
+                # from the dataset (TODO: find out how to check it):
+                # targets = np.argmax(self.model_predict(image_np))
+                targets = list(gt_bbox_dict.keys())
+                intersected_targets = list(set(targets) & set(self.dataset_label_list))
+                if len(intersected_targets) == 0:
+                    # Skip images where gt classes and model classes do not match
+                    continue
+                explanation = self.explainer(
+                    image_np,
+                    targets=intersected_targets,
+                    label_names=self.dataset_label_list,
+                    colormap=False,
+                    **black_box_kwargs,
+                )
+                images.append(image_np)
+                explanations.append(explanation)
+                dataset_gt_bboxes.append(gt_bbox_dict)
+
+            # Write per-batch statistics to track failures
+            explained_images += len(explanations)
+            record = {"range": f"{lrange}-{rrange}"}
+            record.update(self.get_xai_metrics(explanations, images, dataset_gt_bboxes, start_time))
+            records.append(record)
+
+            df = pd.DataFrame([record]).round(3)
+            df.to_csv(self.data_metric_path / f"accuracy_{self.dataset_name}.csv", mode="a", header=False, index=False)
+
+        experiment_time = time() - experiment_start_time
+        mean_scores_dict = {"explained_images": explained_images, "overall_time": experiment_time}
+        mean_scores_dict.update(
+            {
+                key: np.mean([record[key] for record in records if key in record])
+                for key in records[0].keys()
+                if key != "range"
+            }
+        )
+        df = pd.DataFrame([mean_scores_dict]).round(3)
+        df.to_csv(self.data_metric_path / f"mean_accuracy_{self.dataset_name}.csv", index=False)
+
+    def get_xai_metrics(
+        self,
+        explanations: List[Explanation],
+        images: List[np.ndarray],
+        dataset_gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]],
+        start_time: float,
+    ):
+        score = {}
+        if len(explanations) == 0:
+            return score
+
+        def evaluate_metric_time(metric_name, evaluation_func, *args, **kwargs):
+            previous_time = time()
+            score.update(evaluation_func(*args, **kwargs))
+            score[f"{metric_name}_time"] = time() - previous_time
+
+        score["explain_time"] = time() - start_time
+        evaluate_metric_time("pointing_game", self.pointing_game.evaluate, explanations, dataset_gt_bboxes)
+        evaluate_metric_time("auc", self.auc.evaluate, explanations, images, steps=30)
+        evaluate_metric_time("adcc", self.adcc.evaluate, explanations, images)
+
+        return score
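For readers who want to exercise the same flow outside the test harness, below is a minimal sketch assembled from the calls this test makes. The IR path, the sample image, and the label list are placeholders, and the normalization constants are the usual ImageNet values that setup_process_fn derives from the timm model config:

import numpy as np
import openvino as ov
from PIL import Image

from openvino_xai import Task
from openvino_xai.common.parameters import Method
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.utils import (
    ActivationType,
    get_postprocess_fn,
    get_preprocess_fn,
)

# Placeholder IR path; the test produces this via convert_timm_to_ir() or retrieve_otx_model().
model = ov.Core().read_model("model_fp32.xml")

preprocess_fn = get_preprocess_fn(
    change_channel_order=True,
    input_size=(224, 224),
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    hwc_to_chw=True,
)
postprocess_fn = get_postprocess_fn(activation=ActivationType.SOFTMAX)

explainer = Explainer(
    model=model,
    task=Task.CLASSIFICATION,
    preprocess_fn=preprocess_fn,
    postprocess_fn=postprocess_fn,
    explain_mode=ExplainMode.WHITEBOX,
    explain_method=Method.RECIPROCAM,
    embed_scaling=True,
)

image_np = np.array(Image.open("example.jpg"))  # placeholder image, PIL -> np.array
label_names = ["cat", "dog"]  # placeholder; the test uses IMAGENET_LABELS or VOC_NAMES
explanation = explainer(image_np, targets=["cat"], label_names=label_names, colormap=False)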

0 commit comments
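Continuing the sketch, the three metric objects are constructed and queried exactly as in get_xai_metrics(); each evaluate() call returns a dict of scores that can be merged into one record:

from openvino_xai.metrics import ADCC, InsertionDeletionAUC, PointingGame

# model, preprocess_fn, postprocess_fn, and explainer come from the sketch above;
# explanations, images, and dataset_gt_bboxes are collected as in the test loop.
pointing_game = PointingGame()
auc = InsertionDeletionAUC(model, preprocess_fn, postprocess_fn)
adcc = ADCC(model, preprocess_fn, postprocess_fn, explainer)

scores = {}
scores.update(pointing_game.evaluate(explanations, dataset_gt_bboxes))
scores.update(auc.evaluate(explanations, images, steps=30))
scores.update(adcc.evaluate(explanations, images))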

Comments
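The evaluate_metric_time closure can also be lifted into a standalone helper; here is a self-contained sketch of the same timing pattern (the helper name is ours, not part of the library):

from time import time
from typing import Callable, Dict


def timed_evaluate(score: Dict[str, float], name: str, evaluate: Callable[..., dict], *args, **kwargs) -> None:
    # Merge the metric's scores into `score` and record its wall-clock runtime,
    # mirroring evaluate_metric_time in the test.
    start = time()
    score.update(evaluate(*args, **kwargs))
    score[f"{name}_time"] = time() - start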
 (0)
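Finally, the mean_accuracy CSV is a plain column-wise mean over the per-batch records. With toy records (the metric keys here are illustrative, not the exact keys the metrics return), the aggregation in the test reduces to:

import numpy as np
import pandas as pd

# Toy per-batch records shaped like those built in test_explainer_images.
records = [
    {"range": "0-1000", "explain_time": 12.0, "adcc": 0.61},
    {"range": "1000-2000", "explain_time": 11.5, "adcc": 0.64},
]

mean_scores = {
    key: np.mean([record[key] for record in records if key in record])
    for key in records[0].keys()
    if key != "range"
}
print(pd.DataFrame([mean_scores]).round(3))  # explain_time=11.75, adcc=0.625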