
Commit 160fdfa

Authored by Evgeny Tsykunov
Flatten parameter objects (#16)
* Flatten explanation and visualization parameter objects
* Remove leftover
* Fix tests
* Docstrings and interface
* explain parameters -> mode
* Flatten insertion params
* Decompose method creation in explainer
* Fix code quality
* Fix comments + refactor visualizer
* Make visualizer stateless
* black + isort
* Minor leftover
1 parent c166de3 commit 160fdfa

28 files changed: +604 -706 lines changed
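In short, the intermediate parameter objects go away and their fields become plain keyword arguments of Explainer, of the explainer call, and of insert_xai. A rough before/after sketch of the classification flow, condensed from the example changes below (model, image and preprocess_fn are assumed to be prepared as in examples/run_classification.py):

# Before (parameter objects):
explainer = xai.Explainer(
    model=model,
    task=xai.Task.CLASSIFICATION,
    preprocess_fn=preprocess_fn,
    explain_mode=ExplainMode.WHITEBOX,
    insertion_parameters=ClassificationInsertionParameters(
        explain_method=xai.Method.RECIPROCAM,
    ),
)
explanation_parameters = ExplanationParameters(
    target_explain_group=TargetExplainGroup.CUSTOM,
    target_explain_labels=[11, 14],
    visualization_parameters=VisualizationParameters(overlay=True),
)
explanation = explainer(image, explanation_parameters)

# After (flat keyword arguments):
explainer = xai.Explainer(
    model=model,
    task=xai.Task.CLASSIFICATION,
    preprocess_fn=preprocess_fn,
    explain_mode=ExplainMode.WHITEBOX,
    explain_method=xai.Method.RECIPROCAM,
)
explanation = explainer(
    image,
    target_explain_group=TargetExplainGroup.CUSTOM,
    target_explain_labels=[11, 14],
    overlay=True,
)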

examples/run_classification.py

+35 -63
@@ -12,13 +12,8 @@
 
 import openvino_xai as xai
 from openvino_xai.common.utils import logger
-from openvino_xai.explainer.parameters import (
-    ExplainMode,
-    ExplanationParameters,
-    TargetExplainGroup,
-    VisualizationParameters,
-)
-from openvino_xai.inserter.parameters import ClassificationInsertionParameters
+from openvino_xai.explainer.explain_group import TargetExplainGroup
+from openvino_xai.explainer.explainer import ExplainMode
 
 
 def get_argument_parser():
@@ -61,14 +56,14 @@ def explain_auto(args):
 
     # Prepare input image and explanation parameters, can be different for each explain call
     image = cv2.imread(args.image_path)
-    explanation_parameters = ExplanationParameters(
+
+    # Generate explanation
+    explanation = explainer(
+        image,
         target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
         target_explain_labels=[11, 14],  # target classes to explain
     )
 
-    # Generate explanation
-    explanation = explainer(image, explanation_parameters)
-
     logger.info(
         f"explain_auto: Generated {len(explanation.saliency_map)} classification "
         f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."
@@ -83,44 +78,39 @@
 def explain_white_box(args):
     """
     Advanced use case using ExplainMode.WHITEBOX.
-    insertion_parameters are provided to further configure the white-box method.
+    Insertion parameters (e.g. target_layer) are provided to further configure the white-box method (optional).
     """
 
     # Create ov.Model
     model: ov.Model
     model = ov.Core().read_model(args.model_path)
 
-    # Optional - define insertion parameters
-    insertion_parameters = ClassificationInsertionParameters(
-        # target_layer="last_conv_node_name",  # target_layer - node after which XAI branch will be inserted
-        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
-        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
-        embed_scaling=True,  # True by default. If set to True, saliency map scale (0 ~ 255) operation is embedded in the model
-        explain_method=xai.Method.RECIPROCAM,  # ReciproCAM is the default XAI method for CNNs
-    )
-
     # Create explainer object
     explainer = xai.Explainer(
         model=model,
         task=xai.Task.CLASSIFICATION,
         preprocess_fn=preprocess_fn,
         explain_mode=ExplainMode.WHITEBOX,  # defaults to AUTO
-        insertion_parameters=insertion_parameters,
+        explain_method=xai.Method.RECIPROCAM,  # ReciproCAM is the default XAI method for CNNs
+        # target_layer="last_conv_node_name",  # target_layer - node after which XAI branch will be inserted
+        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
+        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
+        embed_scaling=True,  # True by default. If set to True, saliency map scale (0 ~ 255) operation is embedded in the model
     )
 
     # Prepare input image and explanation parameters, can be different for each explain call
     image = cv2.imread(args.image_path)
     voc_labels = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                   'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
-    explanation_parameters = ExplanationParameters(
+
+    # Generate explanation
+    explanation = explainer(
+        image,
         target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
         target_explain_labels=[11, 14],  # target classes to explain, also ['dog', 'person'] is a valid input
         label_names=voc_labels,  # optional names
-        visualization_parameters=VisualizationParameters(overlay=True)
-    )
-
-    # Generate explanation
-    explanation = explainer(image, explanation_parameters)
+        overlay=True,
+    )
 
     logger.info(
         f"explain_white_box: Generated {len(explanation.saliency_map)} classification "
@@ -156,17 +146,14 @@ def explain_black_box(args):
     image = cv2.imread(args.image_path)
     voc_labels = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                   'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
-    explanation_parameters = ExplanationParameters(
-        target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
-        target_explain_labels=['dog', 'person'],  # target classes to explain, also [11, 14] possible
-        label_names=voc_labels,  # optional names
-        visualization_parameters=VisualizationParameters(overlay=True)
-    )
 
     # Generate explanation
     explanation = explainer(
         image,
-        explanation_parameters,
+        target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
+        target_explain_labels=['dog', 'person'],  # target classes to explain, also [11, 14] possible
+        label_names=voc_labels,  # optional names
+        overlay=True,
        num_masks=1000,  # kwargs of the RISE algo
    )
 
@@ -197,11 +184,6 @@ def explain_white_box_multiple_images(args):
         preprocess_fn=preprocess_fn,
     )
 
-    explanation_parameters = ExplanationParameters(
-        target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
-        target_explain_labels=[14],  # target classes to explain
-    )
-
     # Create list of images
     img_data_formats = (".jpg", ".jpeg", ".gif", ".bmp", ".tif", ".tiff", ".png")
     if args.image_path.lower().endswith(img_data_formats):
@@ -216,7 +198,7 @@
 
     # Generate explanation
     images = [cv2.imread(image_path) for image_path in img_files]
-    explanation = [explainer(image, explanation_parameters) for image in images]
+    explanation = [explainer(image, target_explain_group=TargetExplainGroup.CUSTOM, target_explain_labels=[14]) for image in images]
 
     logger.info(
         f"explain_white_box_multiple_images: Generated {len(explanation)} explanations "
@@ -236,32 +218,27 @@ def explain_white_box_vit(args):
     model: ov.Model
     model = ov.Core().read_model(args.model_path)
 
-    # Optional - define insertion parameters
-    insertion_parameters = ClassificationInsertionParameters(
-        # target_layer="/layers.10/ffn/Add",  # OTX deit-tiny
-        # target_layer="/blocks/blocks.10/Add_1",  # timm vit_base_patch8_224.augreg_in21k_ft_in1k
-        explain_method=xai.Method.VITRECIPROCAM,
-    )
-
     # Create explainer object
     explainer = xai.Explainer(
         model=model,
         task=xai.Task.CLASSIFICATION,
         preprocess_fn=preprocess_fn,
         explain_mode=ExplainMode.WHITEBOX,  # defaults to AUTO
-        insertion_parameters=insertion_parameters,
+        explain_method=xai.Method.VITRECIPROCAM,
+        # target_layer="/layers.10/ffn/Add",  # OTX deit-tiny
+        # target_layer="/blocks/blocks.10/Add_1",  # timm vit_base_patch8_224.augreg_in21k_ft_in1k
     )
 
     # Prepare input image and explanation parameters, can be different for each explain call
     image = cv2.imread(args.image_path)
-    explanation_parameters = ExplanationParameters(
+
+    # Generate explanation
+    explanation = explainer(
+        image,
         target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
         target_explain_labels=[0, 1, 2, 3, 4],  # target classes to explain
     )
 
-    # Generate explanation
-    explanation = explainer(image, explanation_parameters)
-
     logger.info(
         f"explain_white_box_vit: Generated {len(explanation.saliency_map)} classification "
         f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."
@@ -298,26 +275,21 @@ def insert_xai(args):
 def insert_xai_w_params(args):
     """
     White-box scenario.
-    Insertion of the XAI branch into the IR with insertion parameters, thus IR has additional 'saliency_map' output.
+    Insertion of the XAI branch into the IR with insertion parameters (e.g. target_layer), thus, IR has additional 'saliency_map' output.
     """
 
     # Create ov.Model
     model: ov.Model
     model = ov.Core().read_model(args.model_path)
 
-    # Define insertion parameters
-    insertion_parameters = ClassificationInsertionParameters(
-        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
-        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
-        embed_scaling=True,
-        explain_method=xai.Method.RECIPROCAM,
-    )
-
     # insert XAI branch
     model_xai = xai.insert_xai(
         model,
         task=xai.Task.CLASSIFICATION,
-        insertion_parameters=insertion_parameters,
+        explain_method=xai.Method.RECIPROCAM,
+        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
+        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
+        embed_scaling=True,
    )
 
     logger.info("insert_xai_w_params: XAI branch inserted into IR with parameters.")

examples/run_detection.py

+8 -17
@@ -11,12 +11,8 @@
 
 import openvino_xai as xai
 from openvino_xai.common.utils import logger
-from openvino_xai.explainer.parameters import (
-    ExplainMode,
-    ExplanationParameters,
-    TargetExplainGroup,
-)
-from openvino_xai.inserter.parameters import DetectionInsertionParameters
+from openvino_xai.explainer.explain_group import TargetExplainGroup
+from openvino_xai.explainer.explainer import ExplainMode
 
 
 def get_argument_parser():
@@ -62,32 +58,27 @@ def main(argv):
     # "/bbox_head/atss_cls_3/Conv/WithoutBiases",
     # "/bbox_head/atss_cls_4/Conv/WithoutBiases",
     # ]
-    insertion_parameters = DetectionInsertionParameters(
-        target_layer=cls_head_output_node_names,
-        # num_anchors=[1, 1, 1, 1, 1],
-        saliency_map_size=(23, 23),  # Optional
-        explain_method=xai.Method.DETCLASSPROBABILITYMAP,  # Optional
-    )
 
     # Create explainer object
     explainer = xai.Explainer(
         model=model,
         task=xai.Task.DETECTION,
         preprocess_fn=preprocess_fn,
         explain_mode=ExplainMode.WHITEBOX,  # defaults to AUTO
-        insertion_parameters=insertion_parameters,
+        target_layer=cls_head_output_node_names,
+        saliency_map_size=(23, 23),  # Optional
     )
 
     # Prepare input image and explanation parameters, can be different for each explain call
     image = cv2.imread(args.image_path)
-    explanation_parameters = ExplanationParameters(
+
+    # Generate explanation
+    explanation = explainer(
+        image,
         target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
         target_explain_labels=[0, 1, 2, 3, 4],  # target classes to explain
     )
 
-    # Generate explanation
-    explanation = explainer(image, explanation_parameters)
-
     logger.info(
         f"Generated {len(explanation.saliency_map)} detection "
         f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."

notebooks/xai_saliency_map_interpretation/xai_saliency_map_interpretation.ipynb

-2
@@ -77,8 +77,6 @@
     "import openvino_xai as xai\n",
     "\n",
     "from openvino_xai.common.utils import retrieve_otx_model\n",
-    "from openvino_xai.explainer.parameters import (\n",
-    "    ExplanationParameters, VisualizationParameters)\n",
     "from openvino_xai.explainer.utils import ActivationType"
    ]
   },

openvino_xai/api/api.py

+17 -8
@@ -1,18 +1,22 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import List
+
 import openvino.runtime as ov
 
-from openvino_xai.common.parameters import Task
+from openvino_xai.common.parameters import Method, Task
 from openvino_xai.common.utils import IdentityPreprocessFN, has_xai, logger
-from openvino_xai.inserter.parameters import InsertionParameters
 from openvino_xai.methods.factory import WhiteBoxMethodFactory
 
 
 def insert_xai(
     model: ov.Model,
     task: Task,
-    insertion_parameters: InsertionParameters | None = None,
+    explain_method: Method | None = None,
+    target_layer: str | List[str] | None = None,
+    embed_scaling: bool | None = True,
+    **kwargs,
 ) -> ov.Model:
     """
     Function that inserts XAI branch into IR.
@@ -24,10 +28,12 @@ def insert_xai(
     :type model: ov.Model | str
     :param task: Type of the task: CLASSIFICATION or DETECTION.
     :type task: Task
-    :param insertion_parameters: Insertion parameters that parametrize white-box method,
-        that will be inserted into the model graph (optional).
-    :type insertion_parameters: InsertionParameters
-    :return: IR with XAI branch.
+    :parameter explain_method: Explain method to use for model explanation.
+    :type explain_method: Method
+    :parameter target_layer: Target layer(s) (node(s)) name after which the XAI branch will be inserted.
+    :type target_layer: str | List[str]
+    :parameter embed_scaling: If set to True, saliency map scale (0 ~ 255) operation is embedded in the model.
+    :type embed_scaling: bool
     """
 
     if has_xai(model):
@@ -38,8 +44,11 @@ def insert_xai(
         task=task,
         model=model,
         preprocess_fn=IdentityPreprocessFN(),
-        insertion_parameters=insertion_parameters,
+        explain_method=explain_method,
+        target_layer=target_layer,
+        embed_scaling=embed_scaling,
         prepare_model=False,
+        **kwargs,
     )
 
     model_xai = method.prepare_model(load_model=False)
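For downstream code, the flattened insert_xai signature is called the way the classification example above does it; a minimal sketch (the IR path is a placeholder, and the target layer is the model-specific node name from the OTX mobilenet example):

import openvino.runtime as ov
import openvino_xai as xai

model = ov.Core().read_model("path/to/model.xml")  # placeholder IR path
model_xai = xai.insert_xai(
    model,
    task=xai.Task.CLASSIFICATION,
    explain_method=xai.Method.RECIPROCAM,
    target_layer="/backbone/conv/conv.2/Div",  # model-specific node name (OTX mnet_v3 example)
    embed_scaling=True,
)
# model_xai is the same IR with an additional 'saliency_map' output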

openvino_xai/explainer/__init__.py

+2 -9
@@ -3,22 +3,15 @@
 """
 Interface for getting explanation.
 """
-from openvino_xai.explainer.explainer import Explainer
+from openvino_xai.explainer.explain_group import TargetExplainGroup
+from openvino_xai.explainer.explainer import Explainer, ExplainMode
 from openvino_xai.explainer.explanation import Explanation, Layout
-from openvino_xai.explainer.parameters import (
-    ExplainMode,
-    ExplanationParameters,
-    TargetExplainGroup,
-    VisualizationParameters,
-)
 from openvino_xai.explainer.visualizer import Visualizer, colormap, overlay, resize
 
 __all__ = [
     "Explainer",
     "ExplainMode",
     "TargetExplainGroup",
-    "VisualizationParameters",
-    "ExplanationParameters",
     "Layout",
     "Explanation",
     "Visualizer",
"Visualizer",
openvino_xai/explainer/explain_group.py (new file)

+20
@@ -0,0 +1,20 @@
+# Copyright (C) 2023-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from enum import Enum
+
+
+class TargetExplainGroup(Enum):
+    """
+    Enum describes different target explanation groups.
+
+    Contains the following values:
+        IMAGE - Global (single) saliency map per image.
+        ALL - Saliency map per each possible target.
+        PREDICTIONS - Saliency map per each prediction (prediction = target).
+        CUSTOM - Saliency map per each specified target.
+    """
+
+    IMAGE = "image"
+    ALL = "all"
+    CUSTOM = "custom"
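For reference, a minimal sketch of how the relocated enum is used with the flattened explainer call, assuming an explainer and image prepared as in examples/run_classification.py:

from openvino_xai.explainer.explain_group import TargetExplainGroup

# Saliency map for every possible target
explanation_all = explainer(image, target_explain_group=TargetExplainGroup.ALL)

# Saliency maps for a custom subset of targets
explanation_custom = explainer(
    image,
    target_explain_group=TargetExplainGroup.CUSTOM,
    target_explain_labels=[11, 14],
)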
