@@ -144,6 +144,12 @@ class TestImageClassificationTimm:
144
144
21843 : 2441 , # 2441 is a cheetah class_id in the ImageNet-21k dataset
145
145
11821 : 1652 , # 1652 is a cheetah class_id in the ImageNet-12k dataset
146
146
}
147
+ reference_maps_names = {
148
+ (ExplainMode .WHITEBOX , Method .RECIPROCAM ): Path ("resnet18.a1_in1k_reciprocam.npy" ),
149
+ (ExplainMode .WHITEBOX , Method .ACTIVATIONMAP ): Path ("resnet18.a1_in1k_activationmap.npy" ),
150
+ (ExplainMode .BLACKBOX , Method .AISE ): Path ("resnet18.a1_in1k_aise.npy" ),
151
+ (ExplainMode .BLACKBOX , Method .RISE ): Path ("resnet18.a1_in1k_rise.npy" ),
152
+ }
147
153
148
154
@pytest .fixture (autouse = True )
149
155
def setup (self , fxt_data_root , fxt_output_root , fxt_clear_cache ):
@@ -504,6 +510,62 @@ def test_torch_insert_xai_with_layer(self, model_id: str, detect: str):
504
510
505
511
self .clear_cache ()
506
512
513
+ @pytest .mark .parametrize (
514
+ "explain_mode, explain_method" ,
515
+ [
516
+ (ExplainMode .WHITEBOX , Method .RECIPROCAM ),
517
+ (ExplainMode .WHITEBOX , Method .ACTIVATIONMAP ),
518
+ (ExplainMode .BLACKBOX , Method .AISE ),
519
+ (ExplainMode .BLACKBOX , Method .RISE ),
520
+ ],
521
+ )
522
+ def test_reference_map (self , explain_mode , explain_method ):
523
+ model_id = "resnet18.a1_in1k"
524
+ model_dir = self .data_dir / "timm_models" / "converted_models"
525
+ _ , model_cfg = self .get_timm_model (model_id , model_dir )
526
+
527
+ ir_path = model_dir / model_id / "model_fp32.xml"
528
+ model = ov .Core ().read_model (ir_path )
529
+
530
+ mean_values = [(item * 255 ) for item in model_cfg ["mean" ]]
531
+ scale_values = [(item * 255 ) for item in model_cfg ["std" ]]
532
+ preprocess_fn = get_preprocess_fn (
533
+ change_channel_order = True ,
534
+ input_size = model_cfg ["input_size" ][1 :],
535
+ mean = mean_values ,
536
+ std = scale_values ,
537
+ hwc_to_chw = True ,
538
+ )
539
+
540
+ explainer = Explainer (
541
+ model = model ,
542
+ task = Task .CLASSIFICATION ,
543
+ preprocess_fn = preprocess_fn ,
544
+ postprocess_fn = get_postprocess_fn (),
545
+ explain_mode = explain_mode ,
546
+ explain_method = explain_method ,
547
+ embed_scaling = False ,
548
+ )
549
+
550
+ target_class = self .supported_num_classes [model_cfg ["num_classes" ]]
551
+ image = cv2 .imread ("tests/assets/cheetah_person.jpg" )
552
+ explanation = explainer (
553
+ image ,
554
+ original_input_image = image ,
555
+ targets = [target_class ],
556
+ resize = False ,
557
+ colormap = False ,
558
+ )
559
+
560
+ if explain_method == Method .ACTIVATIONMAP :
561
+ generated_map = explanation .saliency_map ["per_image_map" ]
562
+ else :
563
+ generated_map = explanation .saliency_map [target_class ]
564
+
565
+ reference_maps_path = Path ("tests/assets/reference_maps" )
566
+ reference_map = np .load (reference_maps_path / self .reference_maps_names [(explain_mode , explain_method )])
567
+ assert np .all (np .abs (generated_map .astype (np .int16 ) - reference_map .astype (np .int16 )) <= 3 )
568
+
507
569
def check_for_saved_map (self , model_id , directory ):
508
570
for target in self .supported_num_classes .values ():
509
571
map_name = model_id + "_target_" + str (target ) + ".jpg"
0 commit comments