|
1 | 1 | """
|
2 |
| - Copyright (c) 2021-2023 Intel Corporation |
| 2 | + Copyright (c) 2021-2024 Intel Corporation |
3 | 3 |
|
4 | 4 | Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | you may not use this file except in compliance with the License.
|
|
15 | 15 | """
|
16 | 16 |
|
17 | 17 | import json
|
| 18 | +from typing import Dict |
18 | 19 |
|
19 | 20 | import numpy as np
|
20 | 21 | from openvino.preprocess import PrePostProcessor
|
@@ -170,11 +171,29 @@ def postprocess(self, outputs, meta):
|
170 | 171 |
|
171 | 172 | return ClassificationResult(
|
172 | 173 | result,
|
173 |
| - outputs.get(_saliency_map_name, np.ndarray(0)), |
| 174 | + self.get_saliency_maps(outputs), |
174 | 175 | outputs.get(_feature_vector_name, np.ndarray(0)),
|
175 | 176 | raw_scores,
|
176 | 177 | )
|
177 | 178 |
|
| 179 | + def get_saliency_maps(self, outputs: Dict) -> np.ndarray: |
| 180 | + """ |
| 181 | + Returns saliency map model output. In hierarchical case reorders saliency maps |
| 182 | + to match the order of labels in .XML meta. |
| 183 | + """ |
| 184 | + saliency_maps = outputs.get(_saliency_map_name, np.ndarray(0)) |
| 185 | + if not self.hierarchical: |
| 186 | + return saliency_maps |
| 187 | + |
| 188 | + reordered_saliency_maps = [[] for _ in range(len(saliency_maps))] |
| 189 | + model_classes = self.hierarchical_info["cls_heads_info"]["class_to_group_idx"] |
| 190 | + label_to_model_out_idx = {lbl: i for i, lbl in enumerate(model_classes.keys())} |
| 191 | + for batch in range(len(saliency_maps)): |
| 192 | + for label in self.labels: |
| 193 | + idx = label_to_model_out_idx[label] |
| 194 | + reordered_saliency_maps[batch].append(saliency_maps[batch][idx]) |
| 195 | + return np.array(reordered_saliency_maps) |
| 196 | + |
178 | 197 | def get_all_probs(self, logits: np.ndarray):
|
179 | 198 | if self.multilabel:
|
180 | 199 | probs = sigmoid_numpy(logits.reshape(-1))
|
|
0 commit comments