Skip to content

Commit a7b16fa

Browse files
author
Evgeny Tsykunov
authoredSep 20, 2024
Add reference tests with full map (#70)
* Add reference test for resnet * fix black * fix path * minor * numerical stability
1 parent 89ffabf commit a7b16fa

6 files changed

+62
-1
lines changed
 

‎openvino_xai/explainer/explanation.py

-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def save(
177177
:type postfix: str
178178
:param confidence_scores: Dict with confidence scores for each class index. Default is None.
179179
:type confidence_scores: Dict[int, float] | None
180-
181180
"""
182181

183182
os.makedirs(dir_path, exist_ok=True)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

‎tests/intg/test_classification_timm.py

+62
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ class TestImageClassificationTimm:
144144
21843: 2441, # 2441 is a cheetah class_id in the ImageNet-21k dataset
145145
11821: 1652, # 1652 is a cheetah class_id in the ImageNet-12k dataset
146146
}
147+
reference_maps_names = {
148+
(ExplainMode.WHITEBOX, Method.RECIPROCAM): Path("resnet18.a1_in1k_reciprocam.npy"),
149+
(ExplainMode.WHITEBOX, Method.ACTIVATIONMAP): Path("resnet18.a1_in1k_activationmap.npy"),
150+
(ExplainMode.BLACKBOX, Method.AISE): Path("resnet18.a1_in1k_aise.npy"),
151+
(ExplainMode.BLACKBOX, Method.RISE): Path("resnet18.a1_in1k_rise.npy"),
152+
}
147153

148154
@pytest.fixture(autouse=True)
149155
def setup(self, fxt_data_root, fxt_output_root, fxt_clear_cache):
@@ -504,6 +510,62 @@ def test_torch_insert_xai_with_layer(self, model_id: str, detect: str):
504510

505511
self.clear_cache()
506512

513+
@pytest.mark.parametrize(
514+
"explain_mode, explain_method",
515+
[
516+
(ExplainMode.WHITEBOX, Method.RECIPROCAM),
517+
(ExplainMode.WHITEBOX, Method.ACTIVATIONMAP),
518+
(ExplainMode.BLACKBOX, Method.AISE),
519+
(ExplainMode.BLACKBOX, Method.RISE),
520+
],
521+
)
522+
def test_reference_map(self, explain_mode, explain_method):
523+
model_id = "resnet18.a1_in1k"
524+
model_dir = self.data_dir / "timm_models" / "converted_models"
525+
_, model_cfg = self.get_timm_model(model_id, model_dir)
526+
527+
ir_path = model_dir / model_id / "model_fp32.xml"
528+
model = ov.Core().read_model(ir_path)
529+
530+
mean_values = [(item * 255) for item in model_cfg["mean"]]
531+
scale_values = [(item * 255) for item in model_cfg["std"]]
532+
preprocess_fn = get_preprocess_fn(
533+
change_channel_order=True,
534+
input_size=model_cfg["input_size"][1:],
535+
mean=mean_values,
536+
std=scale_values,
537+
hwc_to_chw=True,
538+
)
539+
540+
explainer = Explainer(
541+
model=model,
542+
task=Task.CLASSIFICATION,
543+
preprocess_fn=preprocess_fn,
544+
postprocess_fn=get_postprocess_fn(),
545+
explain_mode=explain_mode,
546+
explain_method=explain_method,
547+
embed_scaling=False,
548+
)
549+
550+
target_class = self.supported_num_classes[model_cfg["num_classes"]]
551+
image = cv2.imread("tests/assets/cheetah_person.jpg")
552+
explanation = explainer(
553+
image,
554+
original_input_image=image,
555+
targets=[target_class],
556+
resize=False,
557+
colormap=False,
558+
)
559+
560+
if explain_method == Method.ACTIVATIONMAP:
561+
generated_map = explanation.saliency_map["per_image_map"]
562+
else:
563+
generated_map = explanation.saliency_map[target_class]
564+
565+
reference_maps_path = Path("tests/assets/reference_maps")
566+
reference_map = np.load(reference_maps_path / self.reference_maps_names[(explain_mode, explain_method)])
567+
assert np.all(np.abs(generated_map.astype(np.int16) - reference_map.astype(np.int16)) <= 3)
568+
507569
def check_for_saved_map(self, model_id, directory):
508570
for target in self.supported_num_classes.values():
509571
map_name = model_id + "_target_" + str(target) + ".jpg"

0 commit comments

Comments
 (0)
Please sign in to comment.