|
1 | 1 | """
|
2 |
| - Copyright (c) 2021-2023 Intel Corporation |
| 2 | + Copyright (c) 2021-2024 Intel Corporation |
3 | 3 |
|
4 | 4 | Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | you may not use this file except in compliance with the License.
|
|
15 | 15 | """
|
16 | 16 |
|
17 | 17 | import json
|
| 18 | +from typing import Dict |
18 | 19 |
|
19 | 20 | import numpy as np
|
| 21 | + |
20 | 22 | from openvino.preprocess import PrePostProcessor
|
21 | 23 | from openvino.runtime import Model, Type
|
22 | 24 | from openvino.runtime import opset10 as opset
|
@@ -170,11 +172,23 @@ def postprocess(self, outputs, meta):
|
170 | 172 |
|
171 | 173 | return ClassificationResult(
|
172 | 174 | result,
|
173 |
| - outputs.get(_saliency_map_name, np.ndarray(0)), |
| 175 | + self.get_saliency_maps(outputs), |
174 | 176 | outputs.get(_feature_vector_name, np.ndarray(0)),
|
175 | 177 | raw_scores,
|
176 | 178 | )
|
177 | 179 |
|
| 180 | + def get_saliency_maps(self, outputs: Dict) -> np.ndarray: |
| 181 | + saliency_maps = outputs.get(_saliency_map_name, np.ndarray(0)) |
| 182 | + if not self.hierarchical: |
| 183 | + return saliency_maps |
| 184 | + # In hierarchical case reorder saliency maps to match the order of labels in .XML meta |
| 185 | + reordered_saliency_maps = [[] for _ in range(len(saliency_maps))] |
| 186 | + for batch in range(len(saliency_maps)): |
| 187 | + for label in self.labels: |
| 188 | + idx = self.hierarchical_info['cls_heads_info']['label_to_idx'][label] |
| 189 | + reordered_saliency_maps[batch].append(saliency_maps[batch][idx]) |
| 190 | + return np.array(reordered_saliency_maps) |
| 191 | + |
178 | 192 | def get_all_probs(self, logits: np.ndarray):
|
179 | 193 | if self.multilabel:
|
180 | 194 | probs = sigmoid_numpy(logits.reshape(-1))
|
|
0 commit comments