
Commit 4ce1903

Add ADCC metric (#57)
* Draft pointing game implementation
* Add insertion deletion AUC
* Add ADCC
* Update AUC
* Introduce BaseMetric as a parent class
* Delete ADCC
* Remove ADCC tests
* Fixes from comments
* Add ADCC
* Remove scaling logic
* Add extra unit test
* Update threshold value
* Update Changelog
1 parent 4e39758 commit 4ce1903

8 files changed: +252 -13 lines

CHANGELOG.md (+3 -1)

@@ -9,6 +9,7 @@
 * Upgrade OpenVINO to 2024.3.0
 * Add saliency map visualization with explanation.plot()
 * Enable flexible naming for saved saliency maps and include confidence scores
+* Add Pointing Game, Insertion-Deletion AUC and ADCC quality metrics for saliency maps
 
 ### What's Changed
 
@@ -22,7 +23,8 @@
 * Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53
 * Enable flexible naming for saved saliency maps and include confidence scores by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/51
 * Add [Pointing Game](https://link.springer.com/article/10.1007/s11263-017-1059-x) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
-* Add [Insertion Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56
+* Add [Insertion-Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56
+* Add [ADCC](https://arxiv.org/abs/2104.10252) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/57
 
 ### Known Issues

openvino_xai/metrics/__init__.py (+6)

@@ -3,3 +3,9 @@
 """
 Metrics in OpenVINO-XAI to check the quality of saliency maps.
 """
+
+from openvino_xai.metrics.adcc import ADCC
+from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
+from openvino_xai.metrics.pointing_game import PointingGame
+
+__all__ = ["ADCC", "InsertionDeletionAUC", "PointingGame"]

openvino_xai/metrics/adcc.py (new file, +132)

from typing import Any, Dict, List

import numpy as np
from scipy import stats as STS

from openvino_xai import Task
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.explanation import Explanation
from openvino_xai.metrics.base import BaseMetric


class ADCC(BaseMetric):
    """
    Implementation of the Average Drop-Coherence-Complexity (ADCC) metric by Poppi, Samuele, et al., 2021.

    References:
        Poppi, Samuele, et al. "Revisiting the evaluation of class activation mapping for explainability:
        A novel metric and experimental analysis." Proceedings of the IEEE/CVF Conference on
        Computer Vision and Pattern Recognition. 2021.
    """

    def __init__(self, model, preprocess_fn, postprocess_fn, explainer=None, device_name="CPU"):
        super().__init__(
            model=model, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, device_name=device_name
        )
        if explainer is None:
            self.explainer = Explainer(
                model=model,
                task=Task.CLASSIFICATION,
                preprocess_fn=self.preprocess_fn,
                explain_mode=ExplainMode.WHITEBOX,
            )
        else:
            self.explainer = explainer

    def average_drop(
        self, saliency_map: np.ndarray, class_idx: int, image: np.ndarray, model_output: np.ndarray
    ) -> float:
        """
        Measures the average percentage drop in confidence for the target class when the model sees only the
        explanation map (image masked with the saliency map) instead of the full image.
        Lower is better.
        """
        confidence_on_input = np.max(model_output)

        masked_image = (image * saliency_map[:, :, None]).astype(np.uint8)
        prediction_on_saliency_map = self.model_predict(masked_image)
        confidence_on_saliency_map = prediction_on_saliency_map[class_idx]

        return max(0.0, confidence_on_input - confidence_on_saliency_map) / confidence_on_input

    def coherency(self, saliency_map: np.ndarray, class_idx: int, image: np.ndarray) -> float:
        """
        Measures the coherency of the saliency map. The explanation map (image masked with the saliency map)
        should contain all the relevant features that explain a prediction and should remove useless features
        in a coherent way. The saliency map and the saliency map of the explanation map should be similar.
        Higher is better.
        """

        masked_image = image * saliency_map[:, :, None]
        saliency_map_mapped_image = self.explainer(masked_image, targets=[class_idx], colormap=False, scaling=False)
        saliency_map_mapped_image = saliency_map_mapped_image.saliency_map[class_idx]

        A, B = saliency_map, saliency_map_mapped_image
        # Pearson correlation coefficient
        Asq, Bsq = A.flatten(), B.flatten()
        y, _ = STS.pearsonr(Asq, Bsq)
        y = (y + 1) / 2

        return y

    @staticmethod
    def complexity(saliency_map: np.ndarray) -> float:
        """
        Measures the complexity of the saliency map. Fewer important pixels -> lower complexity.
        Defined as the L1 norm of the saliency map.
        Lower is better.
        """
        return abs(saliency_map).sum() / (saliency_map.shape[-1] * saliency_map.shape[-2])

    def __call__(self, saliency_map: np.ndarray, class_idx: int, input_image: np.ndarray) -> Dict[str, float]:
        """
        Calculate the ADCC metric for a given saliency map and class index.
        Higher is better.

        Parameters:
            :param saliency_map: Saliency map for the class_idx class (H, W).
            :type saliency_map: np.ndarray
            :param class_idx: The class index of the saliency map.
            :type class_idx: int
            :param input_image: The input image to the model (H, W, C).
            :type input_image: np.ndarray

        Returns:
            :return: A dictionary containing the ADCC, coherency, complexity, and average drop metrics.
            :rtype: Dict[str, float]
        """
        if not (0 <= np.min(saliency_map) and np.max(saliency_map) <= 1):
            # Scale saliency map to [0, 1]
            saliency_map = saliency_map / 255

        model_output = self.model_predict(input_image)

        avgdrop = self.average_drop(saliency_map, class_idx, input_image, model_output)
        coh = self.coherency(saliency_map, class_idx, input_image)
        com = self.complexity(saliency_map)

        adcc = 3 / (1 / coh + 1 / (1 - com) + 1 / (1 - avgdrop))
        return {"adcc": adcc, "coherency": coh, "complexity": com, "average_drop": avgdrop}

    def evaluate(
        self, explanations: List[Explanation], input_images: List[np.ndarray], **kwargs: Any
    ) -> Dict[str, float]:
        """
        Evaluate the ADCC metric over a list of explanations and input images.

        Parameters:
            :param explanations: A list of explanations for each image.
            :type explanations: List[Explanation]
            :param input_images: A list of input images.
            :type input_images: List[np.ndarray]

        Returns:
            :return: A dictionary containing the average ADCC score.
            :rtype: Dict[str, float]
        """
        results = []
        for input_image, explanation in zip(input_images, explanations):
            for class_idx, saliency_map in explanation.saliency_map.items():
                metric_dict = self(saliency_map, int(class_idx), input_image)
                results.append(metric_dict["adcc"])
        adcc = np.mean(np.array(results), axis=0)
        return {"adcc": adcc}

openvino_xai/metrics/base.py (+4 -2)

@@ -13,12 +13,14 @@ class BaseMetric(ABC):
 
     def __init__(
         self,
-        model_compiled: ov.CompiledModel = None,
+        model: ov.Model = None,
         preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
         postprocess_fn: Callable[[np.ndarray], np.ndarray] = None,
+        device_name: str = "CPU",
    ):
         # Pass model_predict to class initialization directly?
-        self.model_compiled = model_compiled
+        self.model = model
+        self.model_compiled = ov.Core().compile_model(model=model, device_name=device_name)
         self.preprocess_fn = preprocess_fn
         self.postprocess_fn = postprocess_fn
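The practical effect of this change is that metric classes now accept an uncompiled `ov.Model` plus an optional `device_name` and compile it internally, so callers no longer pre-compile. A minimal before/after sketch, assuming a model path placeholder and the same helper functions used in the tests below:

```python
import openvino as ov

from openvino_xai.explainer.utils import ActivationType, get_postprocess_fn, get_preprocess_fn
from openvino_xai.metrics import ADCC, InsertionDeletionAUC

model = ov.Core().read_model("model.xml")  # placeholder path
preprocess_fn = get_preprocess_fn(change_channel_order=True, input_size=(224, 224), hwc_to_chw=True)
postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)

# Before: the caller compiled the model and passed a CompiledModel, e.g.
#   compiled_model = ov.Core().compile_model(model, "CPU")
#   auc = InsertionDeletionAUC(compiled_model, preprocess_fn, postprocess_fn)
# After: pass the raw ov.Model; BaseMetric compiles it on the requested device.
auc = InsertionDeletionAUC(model, preprocess_fn, postprocess_fn)
adcc = ADCC(model, preprocess_fn, postprocess_fn, device_name="CPU")
```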

openvino_xai/metrics/pointing_game.py (+3)

@@ -29,6 +29,9 @@ class PointingGame(BaseMetric):
     (2018) 126:1084-1102.
     """
 
+    def __init__(self):
+        pass
+
     @staticmethod
     def __call__(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> Dict[str, float]:
         """

tests/regression/test_regression.py (+20 -8)

@@ -16,8 +16,7 @@
     get_postprocess_fn,
     get_preprocess_fn,
 )
-from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
-from openvino_xai.metrics.pointing_game import PointingGame
+from openvino_xai.metrics import ADCC, InsertionDeletionAUC, PointingGame
 from tests.unit.explanation.test_explanation_utils import VOC_NAMES
 
 MODEL_NAME = "mlc_mobilenetv3_large_voc"
@@ -57,7 +56,6 @@ def load_gt_bboxes(json_coco_path: str) -> List[Dict[str, List[Tuple[int, int, i
 class TestDummyRegression:
     image = cv2.imread(IMAGE_PATH)
     gt_bboxes = load_gt_bboxes(COCO_ANN_PATH)
-    pointing_game = PointingGame()
 
     preprocess_fn = get_preprocess_fn(
         change_channel_order=True,
@@ -73,8 +71,6 @@ def setup(self, fxt_data_root):
         retrieve_otx_model(data_dir, MODEL_NAME)
         model_path = data_dir / "otx_models" / (MODEL_NAME + ".xml")
         model = ov.Core().read_model(model_path)
-        compiled_model = ov.Core().compile_model(model, "CPU")
-        self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)
 
         self.explainer = Explainer(
             model=model,
@@ -83,29 +79,42 @@ def setup(self, fxt_data_root):
             explain_mode=ExplainMode.WHITEBOX,
         )
 
+        self.pointing_game = PointingGame()
+        self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)
+        self.adcc = ADCC(model, self.preprocess_fn, self.postprocess_fn, self.explainer)
+
     def test_explainer_image(self):
         explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
         assert len(explanation.saliency_map) == 1
+
         pointing_game_score = self.pointing_game.evaluate([explanation], self.gt_bboxes)["pointing_game"]
         assert pointing_game_score == 1.0
 
-        explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
-        assert len(explanation.saliency_map) == 1
         auc_score = self.auc.evaluate([explanation], [self.image], steps=10).values()
         insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
         assert insertion_auc_score >= 0.9
         assert deletion_auc_score >= 0.2
         assert delta_auc_score >= 0.7
 
-        # Two classes for saliency maps
+        adcc_score = self.adcc.evaluate([explanation], [self.image])["adcc"]
+        assert adcc_score > 0.9
+
+    def test_explainer_image_2_classes(self):
         explanation = self.explainer(self.image, targets=["person", "cat"], label_names=VOC_NAMES, colormap=False)
         assert len(explanation.saliency_map) == 2
+
+        pointing_game_score = self.pointing_game.evaluate([explanation], self.gt_bboxes)["pointing_game"]
+        assert pointing_game_score == 1.0
+
         auc_score = self.auc.evaluate([explanation], [self.image], steps=10).values()
         insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
         assert insertion_auc_score >= 0.5
         assert deletion_auc_score >= 0.1
         assert delta_auc_score >= 0.35
 
+        adcc_score = self.adcc.evaluate([explanation], [self.image])["adcc"]
+        assert adcc_score > 0.5
+
     def test_explainer_images(self):
         images = [self.image, self.image]
         explanations = []
@@ -122,3 +131,6 @@ def test_explainer_images(self):
         assert insertion_auc_score >= 0.9
         assert deletion_auc_score >= 0.2
         assert delta_auc_score >= 0.7
+
+        adcc_score = self.adcc.evaluate(explanations, images)["adcc"]
+        assert adcc_score > 0.9

tests/unit/metrics/test_adcc.py (new file, +83)

import json
from typing import Callable, List, Mapping

import cv2
import numpy as np
import openvino as ov
import pytest

from openvino_xai import Task
from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.explanation import Explanation
from openvino_xai.explainer.utils import (
    ActivationType,
    get_postprocess_fn,
    get_preprocess_fn,
    sigmoid,
)
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.metrics.adcc import ADCC
from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
from openvino_xai.metrics.pointing_game import PointingGame
from tests.unit.explanation.test_explanation_utils import VOC_NAMES

MODEL_NAME = "mlc_mobilenetv3_large_voc"


class TestADCC:
    image = cv2.imread("tests/assets/cheetah_person.jpg")
    preprocess_fn = get_preprocess_fn(
        change_channel_order=True,
        input_size=(224, 224),
        hwc_to_chw=True,
    )
    postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)

    @pytest.fixture(autouse=True)
    def setup(self, fxt_data_root):
        self.data_dir = fxt_data_root
        retrieve_otx_model(self.data_dir, MODEL_NAME)
        model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml")
        self.model = ov.Core().read_model(model_path)
        self.explainer = Explainer(
            model=self.model,
            task=Task.CLASSIFICATION,
            preprocess_fn=self.preprocess_fn,
            explain_mode=ExplainMode.WHITEBOX,
        )
        self.adcc = ADCC(self.model, self.preprocess_fn, self.postprocess_fn, self.explainer)

    def test_adcc_init_wo_explainer(self):
        adcc_wo_explainer = ADCC(self.model, self.preprocess_fn, self.postprocess_fn)
        assert isinstance(adcc_wo_explainer.explainer, Explainer)

    def test_adcc(self):
        input_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        saliency_map = np.random.rand(224, 224)

        complexity_score = self.adcc.complexity(saliency_map)
        assert complexity_score >= 0.2

        model_output = self.adcc.model_predict(input_image)
        class_idx = np.argmax(model_output)

        average_drop_score = self.adcc.average_drop(saliency_map, class_idx, input_image, model_output)
        assert average_drop_score >= 0.2

        coherency_score = self.adcc.coherency(saliency_map, class_idx, input_image)
        assert coherency_score >= 0.2

        adcc_score = self.adcc(saliency_map, class_idx, input_image)["adcc"]
        assert adcc_score >= 0.4

    def test_evaluate(self):
        input_images = [np.random.rand(224, 224, 3) for _ in range(5)]
        explanations = [
            Explanation({0: np.random.rand(224, 224), 1: np.random.rand(224, 224)}, targets=[0, 1]) for _ in range(5)
        ]

        adcc_score = self.adcc.evaluate(explanations, input_images)["adcc"]

        assert isinstance(adcc_score, float)
        assert 0 <= adcc_score <= 1

tests/unit/metrics/test_auc.py (+1 -2)

@@ -34,8 +34,7 @@ def setup(self, fxt_data_root):
         model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml")
         core = ov.Core()
         model = core.read_model(model_path)
-        compiled_model = core.compile_model(model=model, device_name="AUTO")
-        self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)
+        self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)
 
         self.explainer = Explainer(
             model=model,
