|
| 1 | +from typing import Any, Dict, List |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from scipy import stats as STS |
| 5 | + |
| 6 | +from openvino_xai import Task |
| 7 | +from openvino_xai.explainer.explainer import Explainer, ExplainMode |
| 8 | +from openvino_xai.explainer.explanation import Explanation |
| 9 | +from openvino_xai.metrics.base import BaseMetric |
| 10 | + |
| 11 | + |
| 12 | +class ADCC(BaseMetric): |
| 13 | + """ |
| 14 | +    Implementation of the Average Drop-Coherence-Complexity (ADCC) metric by Poppi, Samuele, et al., 2021. |
| 15 | + |
| 16 | + References: |
| 17 | + Poppi, Samuele, et al. "Revisiting the evaluation of class activation mapping for explainability: |
| 18 | + A novel metric and experimental analysis." Proceedings of the IEEE/CVF Conference on |
| 19 | + Computer Vision and Pattern Recognition. 2021. |
| 20 | + """ |
| 21 | + |
| 22 | + def __init__(self, model, preprocess_fn, postprocess_fn, explainer=None, device_name="CPU"): |
| 23 | + super().__init__( |
| 24 | + model=model, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, device_name=device_name |
| 25 | + ) |
| 26 | + if explainer is None: |
| 27 | + self.explainer = Explainer( |
| 28 | + model=model, |
| 29 | + task=Task.CLASSIFICATION, |
| 30 | + preprocess_fn=self.preprocess_fn, |
| 31 | + explain_mode=ExplainMode.WHITEBOX, |
| 32 | + ) |
| 33 | + else: |
| 34 | + self.explainer = explainer |
| 35 | + |
| 36 | + def average_drop( |
| 37 | + self, saliency_map: np.ndarray, class_idx: int, image: np.ndarray, model_output: np.ndarray |
| 38 | + ) -> float: |
| 39 | + """ |
| 40 | + Measures the average percentage drop in confidence for the target class when the model sees only the |
| 41 | + explanation map (image masked with saliency map), instead of the full image. |
| 42 | +        The lower the better. |
| 43 | + """ |
| 44 | + confidence_on_input = np.max(model_output) |
| 45 | + |
| 46 | + masked_image = (image * saliency_map[:, :, None]).astype(np.uint8) |
| 47 | + prediction_on_saliency_map = self.model_predict(masked_image) |
| 48 | + confidence_on_saliency_map = prediction_on_saliency_map[class_idx] |
| 49 | + |
| 50 | + return max(0.0, confidence_on_input - confidence_on_saliency_map) / confidence_on_input |
| 51 | + |
| 52 | + def coherency(self, saliency_map: np.ndarray, class_idx: int, image: np.ndarray) -> float: |
| 53 | + """ |
| 54 | +        Measures the coherency of the saliency map. The explanation map (image masked with the saliency map) should contain all the relevant features that explain the prediction and should remove useless features in a coherent way. |
| 55 | +        The saliency map of the explanation map should therefore be similar to the original saliency map. |
| 56 | +        The higher the better. |
| 57 | + """ |
| 58 | + |
| 59 | + masked_image = image * saliency_map[:, :, None] |
| 60 | + saliency_map_mapped_image = self.explainer(masked_image, targets=[class_idx], colormap=False, scaling=False) |
| 61 | + saliency_map_mapped_image = saliency_map_mapped_image.saliency_map[class_idx] |
| 62 | + |
| 63 | +        # Pearson correlation coefficient between the original saliency map |
| 64 | +        # and the saliency map computed on the explanation map |
| 65 | +        y, _ = STS.pearsonr(saliency_map.flatten(), saliency_map_mapped_image.flatten()) |
| 66 | +        # Rescale the correlation from [-1, 1] to [0, 1] |
| 67 | +        y = (y + 1) / 2 |
| 68 | + |
| 69 | + return y |
| 70 | + |
| 71 | + @staticmethod |
| 72 | + def complexity(saliency_map: np.ndarray) -> float: |
| 73 | + """ |
| 74 | +        Measures the complexity of the saliency map: the fewer pixels it highlights, the lower the complexity. |
| 75 | +        Defined as the L1 norm of the saliency map, normalized by the number of pixels. |
| 76 | +        The lower the better. |
| 77 | + """ |
| 78 | + return abs(saliency_map).sum() / (saliency_map.shape[-1] * saliency_map.shape[-2]) |
| 79 | + |
| 80 | + def __call__(self, saliency_map: np.ndarray, class_idx: int, input_image: np.ndarray) -> Dict[str, float]: |
| 81 | + """ |
| 82 | + Calculate the ADCC metric for a given saliency map and class index. |
| 83 | +        The higher the better. |
| 84 | + |
| 85 | + Parameters: |
| 86 | +        :param saliency_map: Saliency map for the class_idx class, of shape (H, W). |
| 87 | +        :type saliency_map: np.ndarray |
| 88 | +        :param class_idx: The class index of the saliency map. |
| 89 | + :type class_idx: int |
| 90 | + :param input_image: The input image to the model (H, W, C). |
| 91 | + :type input_image: np.ndarray |
| 92 | + |
| 93 | + Returns: |
| 94 | + :return: A dictionary containing the ADCC, coherency, complexity, and average drop metrics. |
| 95 | + :rtype: Dict[str, float] |
| 96 | + """ |
| 97 | + if not (0 <= np.min(saliency_map) and np.max(saliency_map) <= 1): |
| 98 | +            # Assume an 8-bit saliency map and scale it to [0, 1] |
| 99 | + saliency_map = saliency_map / 255 |
| 100 | + |
| 101 | + model_output = self.model_predict(input_image) |
| 102 | + |
| 103 | + avgdrop = self.average_drop(saliency_map, class_idx, input_image, model_output) |
| 104 | + coh = self.coherency(saliency_map, class_idx, input_image) |
| 105 | + com = self.complexity(saliency_map) |
| 106 | + |
| 107 | + adcc = 3 / (1 / coh + 1 / (1 - com) + 1 / (1 - avgdrop)) |
| 108 | + return {"adcc": adcc, "coherency": coh, "complexity": com, "average_drop": avgdrop} |
| 109 | + |
| 110 | + def evaluate( |
| 111 | + self, explanations: List[Explanation], input_images: List[np.ndarray], **kwargs: Any |
| 112 | + ) -> Dict[str, float]: |
| 113 | + """ |
| 114 | + Evaluate the ADCC metric over a list of explanations and input images. |
| 115 | + |
| 116 | + Parameters: |
| 117 | + :param explanations: A list of explanations for each image. |
| 118 | + :type explanations: List[Explanation] |
| 119 | + :param input_images: A list of input images. |
| 120 | + :type input_images: List[np.ndarray] |
| 121 | + |
| 122 | + Returns: |
| 123 | + :return: A dictionary containing the average ADCC score. |
| 124 | + :rtype: Dict[str, float] |
| 125 | + """ |
| 126 | + results = [] |
| 127 | + for input_image, explanation in zip(input_images, explanations): |
| 128 | + for class_idx, saliency_map in explanation.saliency_map.items(): |
| 129 | + metric_dict = self(saliency_map, int(class_idx), input_image) |
| 130 | + results.append(metric_dict["adcc"]) |
| 131 | + adcc = np.mean(np.array(results), axis=0) |
| 132 | + return {"adcc": adcc} |
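
For context, here is a minimal usage sketch of how this metric could be driven end to end. It is illustrative only: the model path `model.xml`, the test image, the 224x224 input size, and the `preprocess_fn`/`postprocess_fn` helpers are assumptions, and the import path for `ADCC` is assumed as well.

```python
# Minimal usage sketch (illustrative, not part of the diff).
# Assumptions: an OpenVINO IR classification model at "model.xml", an RGB test
# image on disk, and that ADCC is importable from openvino_xai.metrics.
import cv2
import numpy as np
import openvino as ov

from openvino_xai.metrics import ADCC  # import path assumed


def preprocess_fn(image: np.ndarray) -> np.ndarray:
    # Resize to the (assumed) 224x224 model input and add a batch dimension.
    resized = cv2.resize(image, (224, 224))
    return np.expand_dims(resized, 0).astype(np.float32)


def postprocess_fn(model_output) -> np.ndarray:
    # Reduce the raw prediction to a 1D vector of per-class scores
    # (assumes a single-output classification head).
    return np.squeeze(model_output[0])


model = ov.Core().read_model("model.xml")         # hypothetical model path
image = cv2.imread("test_image.jpg")[:, :, ::-1]  # hypothetical image, BGR -> RGB

adcc_metric = ADCC(model, preprocess_fn, postprocess_fn)

# Use the metric's internal white-box explainer to get a per-class saliency map,
# then score that map against the original image.
explanation = adcc_metric.explainer(image, targets=[0], colormap=False)
saliency_map = explanation.saliency_map[0]  # (H, W) map for class index 0
scores = adcc_metric(saliency_map, class_idx=0, input_image=image)
print(scores)  # {"adcc": ..., "coherency": ..., "complexity": ..., "average_drop": ...}
```

For whole-dataset scoring, `evaluate()` accepts lists of `Explanation` objects and input images and returns the mean ADCC over all per-class saliency maps.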