@@ -12,13 +12,8 @@
 
 import openvino_xai as xai
 from openvino_xai.common.utils import logger
-from openvino_xai.explainer.parameters import (
-    ExplainMode,
-    ExplanationParameters,
-    TargetExplainGroup,
-    VisualizationParameters,
-)
-from openvino_xai.inserter.parameters import ClassificationInsertionParameters
+from openvino_xai.explainer.explain_group import TargetExplainGroup
+from openvino_xai.explainer.explainer import ExplainMode
 
 
 def get_argument_parser():
@@ -61,14 +56,14 @@ def explain_auto(args):
 
     # Prepare input image and explanation parameters, can be different for each explain call
     image = cv2.imread(args.image_path)
-    explanation_parameters = ExplanationParameters(
+
+    # Generate explanation
+    explanation = explainer(
+        image,
         target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
         target_explain_labels=[11, 14],  # target classes to explain
     )
 
-    # Generate explanation
-    explanation = explainer(image, explanation_parameters)
-
     logger.info(
         f"explain_auto: Generated {len(explanation.saliency_map)} classification "
         f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."
@@ -83,44 +78,39 @@ def explain_auto(args):
 def explain_white_box(args):
     """
     Advanced use case using ExplainMode.WHITEBOX.
-    insertion_parameters are provided to further configure the white-box method.
+    Insertion parameters (e.g. target_layer) are provided to further configure the white-box method (optional).
     """
 
     # Create ov.Model
     model: ov.Model
     model = ov.Core().read_model(args.model_path)
 
-    # Optional - define insertion parameters
-    insertion_parameters = ClassificationInsertionParameters(
-        # target_layer="last_conv_node_name",  # target_layer - node after which XAI branch will be inserted
-        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
-        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
-        embed_scaling=True,  # True by default. If set to True, saliency map scale (0 ~ 255) operation is embedded in the model
-        explain_method=xai.Method.RECIPROCAM,  # ReciproCAM is the default XAI method for CNNs
-    )
-
     # Create explainer object
     explainer = xai.Explainer(
         model=model,
         task=xai.Task.CLASSIFICATION,
         preprocess_fn=preprocess_fn,
         explain_mode=ExplainMode.WHITEBOX,  # defaults to AUTO
-        insertion_parameters=insertion_parameters,
+        explain_method=xai.Method.RECIPROCAM,  # ReciproCAM is the default XAI method for CNNs
+        # target_layer="last_conv_node_name",  # target_layer - node after which XAI branch will be inserted
+        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
+        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
+        embed_scaling=True,  # True by default. If set to True, saliency map scale (0 ~ 255) operation is embedded in the model
     )
 
     # Prepare input image and explanation parameters, can be different for each explain call
     image = cv2.imread(args.image_path)
     voc_labels = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                   'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
-    explanation_parameters = ExplanationParameters(
+
+    # Generate explanation
+    explanation = explainer(
+        image,
         target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
         target_explain_labels=[11, 14],  # target classes to explain, also ['dog', 'person'] is a valid input
         label_names=voc_labels,  # optional names
-        visualization_parameters=VisualizationParameters(overlay=True)
-    )
-
-    # Generate explanation
-    explanation = explainer(image, explanation_parameters)
+        overlay=True,
+    )
 
     logger.info(
         f"explain_white_box: Generated {len(explanation.saliency_map)} classification "
@@ -156,17 +146,14 @@ def explain_black_box(args):
     image = cv2.imread(args.image_path)
     voc_labels = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
                   'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
-    explanation_parameters = ExplanationParameters(
-        target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
-        target_explain_labels=['dog', 'person'],  # target classes to explain, also [11, 14] possible
-        label_names=voc_labels,  # optional names
-        visualization_parameters=VisualizationParameters(overlay=True)
-    )
 
     # Generate explanation
     explanation = explainer(
         image,
-        explanation_parameters,
+        target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
+        target_explain_labels=['dog', 'person'],  # target classes to explain, also [11, 14] possible
+        label_names=voc_labels,  # optional names
+        overlay=True,
         num_masks=1000,  # kwargs of the RISE algo
     )
 
@@ -197,11 +184,6 @@ def explain_white_box_multiple_images(args):
         preprocess_fn=preprocess_fn,
     )
 
-    explanation_parameters = ExplanationParameters(
-        target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
-        target_explain_labels=[14],  # target classes to explain
-    )
-
     # Create list of images
     img_data_formats = (".jpg", ".jpeg", ".gif", ".bmp", ".tif", ".tiff", ".png")
     if args.image_path.lower().endswith(img_data_formats):
@@ -216,7 +198,7 @@ def explain_white_box_multiple_images(args):
 
     # Generate explanation
     images = [cv2.imread(image_path) for image_path in img_files]
-    explanation = [explainer(image, explanation_parameters) for image in images]
+    explanation = [explainer(image, target_explain_group=TargetExplainGroup.CUSTOM, target_explain_labels=[14]) for image in images]
 
     logger.info(
         f"explain_white_box_multiple_images: Generated {len(explanation)} explanations "
@@ -236,32 +218,27 @@ def explain_white_box_vit(args):
     model: ov.Model
     model = ov.Core().read_model(args.model_path)
 
-    # Optional - define insertion parameters
-    insertion_parameters = ClassificationInsertionParameters(
-        # target_layer="/layers.10/ffn/Add",  # OTX deit-tiny
-        # target_layer="/blocks/blocks.10/Add_1",  # timm vit_base_patch8_224.augreg_in21k_ft_in1k
-        explain_method=xai.Method.VITRECIPROCAM,
-    )
-
     # Create explainer object
     explainer = xai.Explainer(
         model=model,
         task=xai.Task.CLASSIFICATION,
         preprocess_fn=preprocess_fn,
         explain_mode=ExplainMode.WHITEBOX,  # defaults to AUTO
-        insertion_parameters=insertion_parameters,
+        explain_method=xai.Method.VITRECIPROCAM,
+        # target_layer="/layers.10/ffn/Add",  # OTX deit-tiny
+        # target_layer="/blocks/blocks.10/Add_1",  # timm vit_base_patch8_224.augreg_in21k_ft_in1k
     )
 
     # Prepare input image and explanation parameters, can be different for each explain call
     image = cv2.imread(args.image_path)
-    explanation_parameters = ExplanationParameters(
+
+    # Generate explanation
+    explanation = explainer(
+        image,
         target_explain_group=TargetExplainGroup.CUSTOM,  # CUSTOM list of classes to explain, also ALL possible
         target_explain_labels=[0, 1, 2, 3, 4],  # target classes to explain
     )
 
-    # Generate explanation
-    explanation = explainer(image, explanation_parameters)
-
     logger.info(
         f"explain_white_box_vit: Generated {len(explanation.saliency_map)} classification "
         f"saliency maps of layout {explanation.layout} with shape {explanation.shape}."
@@ -298,26 +275,21 @@ def insert_xai(args):
 def insert_xai_w_params(args):
     """
     White-box scenario.
-    Insertion of the XAI branch into the IR with insertion parameters, thus IR has additional 'saliency_map' output.
+    Insertion of the XAI branch into the IR with insertion parameters (e.g. target_layer), thus IR has additional 'saliency_map' output.
     """
 
     # Create ov.Model
     model: ov.Model
     model = ov.Core().read_model(args.model_path)
 
-    # Define insertion parameters
-    insertion_parameters = ClassificationInsertionParameters(
-        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
-        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
-        embed_scaling=True,
-        explain_method=xai.Method.RECIPROCAM,
-    )
-
     # insert XAI branch
     model_xai = xai.insert_xai(
         model,
         task=xai.Task.CLASSIFICATION,
-        insertion_parameters=insertion_parameters,
+        explain_method=xai.Method.RECIPROCAM,
+        target_layer="/backbone/conv/conv.2/Div",  # OTX mnet_v3
+        # target_layer="/backbone/features/final_block/activate/Mul",  # OTX effnet
+        embed_scaling=True,
    )
 
     logger.info("insert_xai_w_params: XAI branch inserted into IR with parameters.")