Skip to content

Commit 0ba6bee

Browse files
authoredFeb 15, 2024
Reorder saliency maps for h-cls according to labels in .xml meta (#156)
* Reorder saliency maps for h-cls according to labels in .xml meta * Fix pre-commit * Fix isort * Fixes from comments * Typo * Rely on `class_to_group_idx` from model instead of `label_to_idx`
1 parent 0045fb3 commit 0ba6bee

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed
 

‎model_api/python/openvino/model_api/models/classification.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright (c) 2021-2023 Intel Corporation
2+
Copyright (c) 2021-2024 Intel Corporation
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
1515
"""
1616

1717
import json
18+
from typing import Dict
1819

1920
import numpy as np
2021
from openvino.preprocess import PrePostProcessor
@@ -170,11 +171,29 @@ def postprocess(self, outputs, meta):
170171

171172
return ClassificationResult(
172173
result,
173-
outputs.get(_saliency_map_name, np.ndarray(0)),
174+
self.get_saliency_maps(outputs),
174175
outputs.get(_feature_vector_name, np.ndarray(0)),
175176
raw_scores,
176177
)
177178

179+
def get_saliency_maps(self, outputs: Dict) -> np.ndarray:
180+
"""
181+
Returns saliency map model output. In hierarchical case reorders saliency maps
182+
to match the order of labels in .XML meta.
183+
"""
184+
saliency_maps = outputs.get(_saliency_map_name, np.ndarray(0))
185+
if not self.hierarchical:
186+
return saliency_maps
187+
188+
reordered_saliency_maps = [[] for _ in range(len(saliency_maps))]
189+
model_classes = self.hierarchical_info["cls_heads_info"]["class_to_group_idx"]
190+
label_to_model_out_idx = {lbl: i for i, lbl in enumerate(model_classes.keys())}
191+
for batch in range(len(saliency_maps)):
192+
for label in self.labels:
193+
idx = label_to_model_out_idx[label]
194+
reordered_saliency_maps[batch].append(saliency_maps[batch][idx])
195+
return np.array(reordered_saliency_maps)
196+
178197
def get_all_probs(self, logits: np.ndarray):
179198
if self.multilabel:
180199
probs = sigmoid_numpy(logits.reshape(-1))

0 commit comments

Comments
 (0)