Skip to content

Commit 2bbf89d

Browse files
Add metric section in usage documentation (#75)
Add metrics documentation
1 parent cb99489 commit 2bbf89d

File tree

4 files changed

+98
-15
lines changed

4 files changed

+98
-15
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ Please find more options and scenarios in the following links:
244244
* [OpenVINO XAI User Guide](docs/source/user-guide.md)
245245
* [OpenVINO Notebook - XAI Basic](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/explainable-ai-1-basic/explainable-ai-1-basic.ipynb)
246246
* [OpenVINO Notebook - XAI Deep Dive](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/explainable-ai-2-deep-dive/explainable-ai-2-deep-dive.ipynb)
247+
* [OpenVINO Notebook - Saliency Map Interpretation](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/notebooks/explainable-ai-3-map-interpretation/explainable-ai-3-map-interpretation.ipynb)
247248

248249
### Playing with the examples
249250

docs/source/user-guide.md

+86-10
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ explainer = xai.Explainer(
224224
model,
225225
task=xai.Task.CLASSIFICATION,
226226
preprocess_fn=preprocess_fn,
227+
explain_mode=ExplainMode.WHITEBOX,
227228
)
228229

229230
# Generate and process saliency maps (as many as required, sequentially)
@@ -237,7 +238,6 @@ voc_labels = [
237238
# Run explanation
238239
explanation = explainer(
239240
image,
240-
explain_mode=ExplainMode.WHITEBOX,
241241
# target_layer="last_conv_node_name", # target_layer - node after which the XAI branch will be inserted, usually the last convolutional layer in the backbone
242242
embed_scaling=True, # True by default. If set to True, the saliency map scale (0 ~ 255) operation is embedded in the model
243243
explain_method=xai.Method.RECIPROCAM, # ReciproCAM is the default XAI method for CNNs
@@ -288,6 +288,7 @@ explainer = xai.Explainer(
288288
model,
289289
task=xai.Task.CLASSIFICATION,
290290
preprocess_fn=preprocess_fn,
291+
explain_mode=ExplainMode.BLACKBOX,
291292
)
292293

293294
# Generate and process saliency maps (as many as required, sequentially)
@@ -296,7 +297,6 @@ image = cv2.imread("path/to/image.jpg")
296297
# Run explanation
297298
explanation = explainer(
298299
image,
299-
explain_mode=ExplainMode.BLACKBOX,
300300
targets=[11, 14], # target classes to explain
301301
# targets=-1, # explain all classes
302302
overlay=True, # False by default
@@ -317,15 +317,19 @@ As mentioned above, saliency map generation requires model inference. In the abo
317317
**Note**: The original model outputs are not affected, and the model should be inferable by the original inference pipeline.
318318

319319
```python
320+
import cv2
320321
import openvino.runtime as ov
321322
import openvino_xai as xai
322-
from openvino_xai.common.utils import softmax
323-
from openvino_xai.explainer.visualizer colormap, overlay
323+
from openvino_xai.explainer.visualizer import colormap, overlay
324324

325325

326326
# Create an ov.Model
327327
model: ov.Model = ov.Core().read_model("path/to/model.xml")
328328

329+
# Get and preprocess image
330+
image = cv2.imread("path/to/image.jpg")
331+
image_norm = preprocess_fn(image)
332+
329333
# Insert XAI branch into the OpenVINO model graph (IR)
330334
model_xai: ov.Model = xai.insert_xai(
331335
model=model,
@@ -338,10 +342,11 @@ model_xai: ov.Model = xai.insert_xai(
338342
# Insert XAI branch into the Pytorch model
339343
# XAI head is inserted using the module hook mechanism internally
340344
# so that users could get additional saliency map without major changes in the original inference pipeline.
345+
import torch
341346
model: torch.nn.Module
342347

343348
# Insert XAI head
344-
model_xai: torch.nn.Module = insert_xai(model=model, task=xai.Task.CLASSIFICATION)
349+
model_xai: torch.nn.Module = xai.insert_xai(model=model, task=xai.Task.CLASSIFICATION)
345350

346351
# Torch XAI model inference
347352
model_xai.eval()
@@ -355,9 +360,9 @@ with torch.no_grad():
355360
# Torch XAI model saliency map
356361
saliency_maps = saliency_maps.numpy(force=True).squeeze(0) # Cxhxw
357362
saliency_map = saliency_maps[label] # hxw saliency_map for the label
358-
saliency_map = colormap(saliency_map[None, :]) # 1xhxw
359-
saliency_map = cv2.resize(saliency_map.squeeze(0), dsize=input_size) # HxW
360-
result_image = overlay(saliency_map, image)
363+
saliency_map = cv2.resize(saliency_map, dsize=image.shape[1::-1])  # HxW
364+
saliency_map = colormap(saliency_map[None, :]) # 1xHxWx3
365+
result_image = overlay(saliency_map, image)[0] # HxWx3
361366
```
362367

363368
## XAI methods
@@ -535,6 +540,7 @@ explainer = xai.Explainer(
535540
model,
536541
task=xai.Task.CLASSIFICATION,
537542
preprocess_fn=preprocess_fn,
543+
explain_mode=ExplainMode.WHITEBOX,
538544
)
539545

540546
voc_labels = [
@@ -548,7 +554,6 @@ image = cv2.imread("path/to/image.jpg")
548554
# Run explanation
549555
explanation = explainer(
550556
image,
551-
explain_mode=ExplainMode.WHITEBOX,
552557
label_names=voc_labels,
553558
targets=[7, 11], # ['cat', 'dog'] also possible as target classes to explain
554559
)
@@ -616,6 +621,7 @@ explainer = xai.Explainer(
616621
model,
617622
task=xai.Task.CLASSIFICATION,
618623
preprocess_fn=preprocess_fn,
624+
explain_mode=ExplainMode.WHITEBOX,
619625
)
620626

621627
voc_labels = [
@@ -638,7 +644,6 @@ scores_dict = {i: score for i, score in zip(result_idxs, result_scores)}
638644
# Run explanation
639645
explanation = explainer(
640646
image,
641-
explain_mode=ExplainMode.WHITEBOX,
642647
label_names=voc_labels,
643648
targets=result_idxs, # target classes to explain
644649
)
@@ -657,6 +662,77 @@ explanation.save(
657662
) # image_name_aeroplane_conf_0.85.jpg
658663
```
659664

665+
## Measure quality metrics of saliency maps
666+
667+
To compare different saliency maps, you can use the implemented quality metrics: Pointing Game, Insertion-Deletion AUC, and ADCC.
668+
669+
- **ADCC (Average Drop-Coherence-Complexity)** ([paper](https://arxiv.org/abs/2104.10252)/[impl](https://github.com/aimagelab/ADCC/)) - averages three submetrics:
670+
- **Average Drop** - The percentage drop in confidence when the model sees only the explanation map (image masked with the saliency map) instead of the full image.
671+
- **Coherence** - The coherency between the saliency map on the input image and saliency map on the explanation map (image masked with the saliency map). Requires generating an extra explanation (can be time-consuming for black box methods).
672+
- **Complexity** - Measures the L1 norm of the saliency map (average value per pixel). Fewer important pixels -> less complexity -> better saliency map.
673+
674+
- **Insertion-Deletion AUC** ([paper](https://arxiv.org/abs/1806.07421)) - Measures the AUC of the curve of model confidence when important pixels are sequentially inserted or deleted. Time-consuming: requires 60 model inferences (30 steps each for the insertion and the deletion process).
675+
676+
- **Pointing Game** ([paper](https://arxiv.org/abs/1608.00507)/[impl](https://github.com/understandable-machine-intelligence-lab/Quantus/blob/main/quantus/metrics/localisation/pointing_game.py)) - Returns True if the most important saliency map pixel falls into the object ground truth bounding box. Requires ground truth annotation, so it is convenient to use on public datasets (COCO, VOC, ILSVRC) rather than individual images (check [accuracy_tests](../../tests/perf/test_accuracy.py) for examples).
677+
678+
679+
```python
680+
import cv2
681+
import numpy as np
682+
import openvino.runtime as ov
683+
from typing import Mapping
684+
685+
import openvino_xai as xai
686+
from openvino_xai.explainer import ExplainMode
687+
from openvino_xai.metrics import ADCC, InsertionDeletionAUC
688+
689+
690+
def preprocess_fn(image: np.ndarray) -> np.ndarray:
691+
"""Preprocess the input image."""
692+
x = cv2.resize(src=image, dsize=(224, 224))
693+
x = x.transpose((2, 0, 1))
694+
processed_image = np.expand_dims(x, 0)
695+
return processed_image
696+
697+
def postprocess_fn(output: Mapping):
698+
"""Postprocess the model output."""
699+
return softmax(output["logits"])
700+
701+
def softmax(x: np.ndarray) -> np.ndarray:
702+
"""Compute softmax values of x."""
703+
e_x = np.exp(x - np.max(x))
704+
return e_x / e_x.sum()
705+
706+
IMAGE_PATH = "path/to/image.jpg"
707+
MODEL_PATH = "path/to/model.xml"
708+
709+
image = cv2.imread(IMAGE_PATH)
710+
model = ov.Core().read_model(MODEL_PATH)
711+
712+
explainer = xai.Explainer(
713+
model,
714+
task=xai.Task.CLASSIFICATION,
715+
preprocess_fn=preprocess_fn,
716+
explain_mode=ExplainMode.WHITEBOX,
717+
explain_method=xai.Method.RECIPROCAM # Also VITRECIPROCAM, AISE, RISE, ACTIVATIONMAP are supported
718+
)
719+
720+
# Generate explanation (if several targets are passed, metrics for all saliency maps will be aggregated)
721+
explanation = explainer(image, targets=14, colormap=False, overlay=False, resize=True)
722+
723+
# Calculate InsertionDeletionAUC metric over the list of explanations and input images
724+
auc = InsertionDeletionAUC(model, preprocess_fn, postprocess_fn)
725+
auc_score = auc.evaluate([explanation], [image], steps=30) # {'insertion': 0.43, 'deletion': 0.09, 'delta': 0.34}
726+
insertion, deletion, delta = auc_score.values()
727+
print(f"Insertion {insertion:.2f}, Deletion {deletion:.2f}, Delta {delta:.2f}")
728+
729+
# Calculate ADCC metric over the list of explanations and input images
730+
adcc = ADCC(model, preprocess_fn, postprocess_fn, explainer)
731+
adcc_score = adcc.evaluate([explanation], [image]) # {'adcc': 0.95, 'coherency': 0.99, 'complexity': 0.13, 'average_drop': 0.0}
732+
adcc, coherency, complexity, average_drop = adcc_score.values()
733+
print(f"ADCC {adcc:.2f}, Coherency {coherency:.2f}, Complexity {complexity:.2f}, Average drop {average_drop:.2f}")
734+
```
735+
660736
## Example scripts
661737

662738
More usage scenarios that can be used with your own models and images as arguments are available in [examples](../../examples).

openvino_xai/metrics/adcc.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212

1313
class ADCC(BaseMetric):
1414
"""
15-
Implementation of the e Average Drop-Coherence-Complexity (ADCC) metric by Poppi, Samuele, et al 2021.
15+
Implementation of the Average Drop-Coherence-Complexity (ADCC) metric by Poppi, Samuele, et al 2021.
1616
1717
References:
18-
Poppi, Samuele, et al. "Revisiting the evaluation of class activation mapping for explainability:
19-
A novel metric and experimental analysis." Proceedings of the IEEE/CVF Conference on
20-
Computer Vision and Pattern Recognition. 2021.
18+
1) Poppi, Samuele, et al. "Revisiting the evaluation of class activation mapping for explainability:
19+
A novel metric and experimental analysis." Proceedings of the IEEE/CVF Conference on
20+
Computer Vision and Pattern Recognition. 2021.
21+
2) Reference implementation:
22+
https://github.com/aimagelab/ADCC/
2123
"""
2224

2325
def __init__(self, model, preprocess_fn, postprocess_fn, explainer=None, device_name="CPU"):
@@ -52,7 +54,8 @@ def average_drop(
5254

5355
def coherency(self, saliency_map: np.ndarray, class_idx: int, image: np.ndarray) -> float:
5456
"""
55-
Measures the coherency of the saliency map. The explanation map (image masked with saliency map) should contain all the relevant features that explain a prediction and should remove useless features in a coherent way.
57+
Measures the coherency of the saliency map. The explanation map (image masked with saliency map) should
58+
contain all the relevant features that explain a prediction and should remove useless features in a coherent way.
5659
Saliency map and saliency map of explanation map should be similar.
5760
The more the better.
5861
"""

openvino_xai/metrics/insertion_deletion_auc.py

+3
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def __call__(
6262
:return: A dictionary containing the AUC scores for insertion and deletion scores.
6363
:rtype: Dict[str, float]
6464
"""
65+
66+
class_idx = np.argmax(self.model_predict(input_image)) if class_idx is None else class_idx
67+
6568
# Sort pixels by descending importance to find the most important pixels
6669
sorted_indices = np.argsort(-saliency_map.flatten())
6770
sorted_indices = np.unravel_index(sorted_indices, saliency_map.shape)

0 commit comments

Comments
 (0)