Skip to content

Commit 10a0d96

Browse files
author
Evgeny Tsykunov
authored
Fix bugs (#69)
* adaptive font_size * update params * logging * font_face -> 2 * adapt text_height * define offset for text * gray cmap * improve gray cmap * fix detection overlay text * fix bhwc * fix save * minor * test plot * fix fp rounding error
1 parent 41ddb11 commit 10a0d96

File tree

6 files changed

+73
-16
lines changed

6 files changed

+73
-16
lines changed

openvino_xai/explainer/explainer.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class ExplainMode(Enum):
3434
Contains the following values:
3535
WHITEBOX - The model is explained in white box mode, i.e. XAI branch is getting inserted into the model graph.
3636
BLACKBOX - The model is explained in black box model.
37+
AUTO - The model is explained in the white-box mode first, if fails - black-box mode will run.
3738
"""
3839

3940
WHITEBOX = "whitebox"

openvino_xai/explainer/explanation.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def save(
188188
map_to_save = cv2.cvtColor(map_to_save, code=cv2.COLOR_RGB2BGR)
189189
if isinstance(target_idx, str):
190190
target_name = "activation_map"
191-
elif self.label_names and isinstance(target_idx, np.int64) and self.task != Task.DETECTION:
191+
elif self.label_names and isinstance(target_idx, (int, np.int64)) and self.task != Task.DETECTION:
192192
target_name = self.label_names[target_idx]
193193
else:
194194
target_name = str(target_idx)
@@ -261,7 +261,12 @@ def _plot_matplotlib(self, checked_targets: list[int | str], num_cols: int) -> N
261261

262262
map_to_plot = self.saliency_map[target_index]
263263

264-
axes[i].imshow(map_to_plot)
264+
if map_to_plot.ndim == 3:
265+
axes[i].imshow(map_to_plot)
266+
elif map_to_plot.ndim == 2:
267+
axes[i].imshow(map_to_plot, cmap="gray")
268+
else:
269+
raise ValueError(f"Saliency map expected to be 3 or 2-dimensional, but got {map_to_plot.ndim}.")
265270
axes[i].axis("off") # Hide the axis
266271
axes[i].set_title(f"Class {label_name}")
267272

openvino_xai/explainer/visualizer.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -174,33 +174,34 @@ def visualize(
174174
# Convert back to dict
175175
return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return)
176176

177-
@staticmethod
178177
def _put_classification_info(
178+
self,
179179
saliency_map_np: np.ndarray,
180180
indices: List[int],
181181
label_names: List[str] | None,
182182
predictions: Dict[int, Prediction] | None,
183183
) -> None:
184-
corner_location = 3, 17
184+
offset = 3
185185
for smap, target_index in zip(range(len(saliency_map_np)), indices):
186186
label = label_names[target_index] if label_names else str(target_index)
187187
if predictions and target_index in predictions:
188188
score = predictions[target_index].score
189189
if score:
190190
label = f"{label}|{score:.2f}"
191191

192+
font_scale, text_height = self._fit_text_to_image(label, offset, saliency_map_np[smap].shape[1])
192193
cv2.putText(
193194
saliency_map_np[smap],
194195
label,
195-
org=corner_location,
196-
fontFace=1,
197-
fontScale=1.3,
196+
org=(offset, text_height + offset),
197+
fontFace=2,
198+
fontScale=font_scale,
198199
color=(255, 0, 0),
199-
thickness=2,
200+
thickness=1,
200201
)
201202

202-
@staticmethod
203203
def _put_detection_info(
204+
self,
204205
saliency_map_np: np.ndarray,
205206
indices: List[int],
206207
label_names: List[str] | None,
@@ -209,6 +210,7 @@ def _put_detection_info(
209210
if not predictions:
210211
return
211212

213+
offset = 7
212214
for smap, target_index in zip(range(len(saliency_map_np)), indices):
213215
saliency_map = saliency_map_np[smap]
214216
label_index = predictions[target_index].label
@@ -220,17 +222,40 @@ def _put_detection_info(
220222

221223
label = label_names[label_index] if label_names else label_index
222224
label_score = f"{label}|{score:.2f}"
223-
box_location = int(x1), int(y1 - 5)
225+
226+
font_scale, _ = self._fit_text_to_image(label_score, x1, saliency_map.shape[1])
227+
box_location = x1, y1 - offset
224228
cv2.putText(
225229
saliency_map,
226230
label_score,
227231
org=box_location,
228-
fontFace=1,
229-
fontScale=1.3,
232+
fontFace=2,
233+
fontScale=font_scale,
230234
color=(255, 0, 0),
231-
thickness=2,
235+
thickness=1,
232236
)
233237

238+
@staticmethod
239+
def _fit_text_to_image(
240+
text: str,
241+
x_start: int,
242+
image_width: int,
243+
font_scale: float = 1.0,
244+
thickness: int = 1,
245+
) -> Tuple[float, int]:
246+
font_face = 2
247+
max_width = image_width - 5
248+
while True:
249+
text_size, _ = cv2.getTextSize(text, font_face, font_scale, thickness)
250+
text_width, text_height = text_size
251+
252+
if x_start + text_width <= max_width:
253+
return font_scale, text_height
254+
255+
font_scale -= 0.1
256+
if abs(font_scale - 0.1) < 0.001:
257+
return font_scale, text_height
258+
234259
@staticmethod
235260
def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray:
236261
if explanation.layout not in GRAY_LAYOUTS:

openvino_xai/methods/black_box/aise/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import openvino.runtime as ov
1111
from scipy.optimize import direct
1212

13-
from openvino_xai.common.utils import IdentityPreprocessFN
13+
from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout
1414
from openvino_xai.methods.black_box.base import BlackBoxXAIMethod
1515

1616

@@ -92,6 +92,8 @@ def _objective_function(self, args) -> float:
9292

9393
kernel_mask = self._mask_generator.generate_kernel_mask(kernel_params)
9494
kernel_mask = np.clip(kernel_mask, 0, 1)
95+
if is_bhwc_layout(self.data_preprocessed):
96+
kernel_mask = np.expand_dims(kernel_mask, 2)
9597

9698
pred_loss_preserve = 0.0
9799
if self.preservation:

tests/unit/explainer/test_explanation.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@
1212
from tests.unit.explainer.test_explanation_utils import VOC_NAMES
1313

1414
SALIENCY_MAPS = (np.random.rand(1, 20, 5, 5) * 255).astype(np.uint8)
15+
SALIENCY_MAPS_DICT = {
16+
0: (np.random.rand(5, 5, 3) * 255).astype(np.uint8),
17+
2: (np.random.rand(5, 5, 3) * 255).astype(np.uint8),
18+
}
19+
SALIENCY_MAPS_DICT_EXCEPTION = {
20+
0: (np.random.rand(5, 5, 3, 2) * 255).astype(np.uint8),
21+
2: (np.random.rand(5, 5, 3, 2) * 255).astype(np.uint8),
22+
}
1523
SALIENCY_MAPS_IMAGE = (np.random.rand(1, 5, 5) * 255).astype(np.uint8)
1624

1725

@@ -106,7 +114,7 @@ def test_plot(self, mocker, caplog):
106114
# Update the num columns for the matplotlib visualization grid
107115
explanation.plot(backend="matplotlib", num_columns=1)
108116

109-
# Class index that is not in saliency maps will be ommitted with message
117+
# Class index that is not in saliency maps will be omitted with message
110118
with caplog.at_level(logging.INFO):
111119
explanation.plot([0, 3], backend="matplotlib")
112120
assert "Provided class index 3 is not available among saliency maps." in caplog.text
@@ -123,3 +131,13 @@ def test_plot(self, mocker, caplog):
123131
# Plot activation map
124132
explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_IMAGE, label_names=None)
125133
explanation.plot()
134+
135+
# Plot colored map
136+
explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_DICT, label_names=None)
137+
explanation.plot()
138+
139+
# Plot wrong map shape
140+
with pytest.raises(Exception) as exc_info:
141+
explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_DICT_EXCEPTION, label_names=None)
142+
explanation.plot()
143+
assert str(exc_info.value) == "Saliency map expected to be 3 or 2-dimensional, but got 4."

tests/unit/explainer/test_visualization.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
from openvino_xai.explainer.visualizer import Visualizer, colormap, overlay, resize
1111
from openvino_xai.methods.base import Prediction
1212

13+
ORIGINAL_INPUT_IMAGE = [
14+
np.ones((100, 100, 3)),
15+
np.ones((10, 10, 3)),
16+
]
17+
1318
SALIENCY_MAPS = [
1419
(np.random.rand(1, 5, 5) * 255).astype(np.uint8),
1520
(np.random.rand(1, 2, 5, 5) * 255).astype(np.uint8),
@@ -97,6 +102,7 @@ def test_overlay():
97102

98103

99104
class TestVisualizer:
105+
@pytest.mark.parametrize("original_input_image", ORIGINAL_INPUT_IMAGE)
100106
@pytest.mark.parametrize("saliency_maps", SALIENCY_MAPS)
101107
@pytest.mark.parametrize("explain_all_classes", EXPLAIN_ALL_CLASSES)
102108
@pytest.mark.parametrize("task", [Task.CLASSIFICATION, Task.DETECTION])
@@ -107,6 +113,7 @@ class TestVisualizer:
107113
@pytest.mark.parametrize("overlay_weight", [0.5, 0.3])
108114
def test_visualizer(
109115
self,
116+
original_input_image,
110117
saliency_maps,
111118
explain_all_classes,
112119
task,
@@ -124,7 +131,6 @@ def test_visualizer(
124131
explanation = Explanation(saliency_maps, targets=explain_targets, task=Task.CLASSIFICATION)
125132

126133
raw_sal_map_dims = len(explanation.shape)
127-
original_input_image = np.ones((20, 20, 3))
128134
visualizer = Visualizer()
129135
explanation = visualizer(
130136
explanation=explanation,

0 commit comments

Comments
 (0)