Skip to content

Commit 114e3c9

Browse files
authored
Fix saliency maps order for h-cls in cpp (#159)
1 parent 41c2026 commit 114e3c9

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

model_api/cpp/models/include/models/classification_model.h

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct HierarchicalConfig {
3737
std::vector<std::pair<std::string, std::string>> label_tree_edges;
3838
std::vector<std::vector<std::string>> all_groups;
3939
std::map<size_t, std::pair<size_t,size_t>> head_idx_to_logits_range;
40+
std::map<size_t, std::string> logit_idx_to_label;
4041
size_t num_multiclass_heads;
4142
size_t num_multilabel_heads;
4243
size_t num_single_label_classes;
@@ -90,4 +91,6 @@ class ClassificationModel : public ImageModel {
9091
std::unique_ptr<ResultBase> get_multilabel_predictions(InferenceResult& infResult, bool add_raw_scores);
9192
std::unique_ptr<ResultBase> get_multiclass_predictions(InferenceResult& infResult, bool add_raw_scores);
9293
std::unique_ptr<ResultBase> get_hierarchical_predictions(InferenceResult& infResult, bool add_raw_scores);
94+
ov::Tensor reorder_saliency_maps(const ov::Tensor&);
95+
9396
};

model_api/cpp/models/src/classification_model.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ std::unique_ptr<ResultBase> ClassificationModel::postprocess(InferenceResult& in
289289
auto saliency_map_iter = infResult.outputsData.find(saliency_map_name);
290290
if (saliency_map_iter != infResult.outputsData.end()) {
291291
cls_res->saliency_map = std::move(saliency_map_iter->second);
292+
cls_res->saliency_map = reorder_saliency_maps(cls_res->saliency_map);
292293
}
293294
auto feature_vector_iter = infResult.outputsData.find(feature_vector_name);
294295
if (feature_vector_iter != infResult.outputsData.end()) {
@@ -384,6 +385,27 @@ std::unique_ptr<ResultBase> ClassificationModel::get_hierarchical_predictions(In
384385
return retVal;
385386
}
386387

388+
ov::Tensor ClassificationModel::reorder_saliency_maps(const ov::Tensor& source_maps) {
389+
if (!hierarchical || source_maps.get_shape().size() == 1) {
390+
return source_maps;
391+
}
392+
393+
auto reordered_maps = ov::Tensor(source_maps.get_element_type(), source_maps.get_shape());
394+
const std::uint8_t* source_maps_ptr = static_cast<std::uint8_t*>(source_maps.data());
395+
std::uint8_t* reordered_maps_ptr = static_cast<std::uint8_t*>(reordered_maps.data());
396+
397+
size_t shape_offset = (source_maps.get_shape().size() == 4) ? 1 : 0;
398+
size_t map_byte_size = source_maps.get_element_type().size() *
399+
source_maps.get_shape()[shape_offset + 1] * source_maps.get_shape()[shape_offset + 2];
400+
401+
for (size_t i = 0; i < source_maps.get_shape()[shape_offset]; ++i) {
402+
size_t new_index = hierarchical_info.label_to_idx[hierarchical_info.logit_idx_to_label[i]];
403+
std::copy_n(source_maps_ptr + i*map_byte_size, map_byte_size, reordered_maps_ptr + new_index * map_byte_size);
404+
}
405+
406+
return reordered_maps;
407+
}
408+
387409
std::unique_ptr<ResultBase> ClassificationModel::get_multiclass_predictions(InferenceResult& infResult, bool add_raw_scores) {
388410
const ov::Tensor& indicesTensor = infResult.outputsData.find(indices_name)->second;
389411
const int* indicesPtr = indicesTensor.data<int>();
@@ -516,6 +538,17 @@ HierarchicalConfig::HierarchicalConfig(const std::string& json_repr) {
516538
for (const auto& range_descr : tmp_head_idx_to_logits_range) {
517539
head_idx_to_logits_range[stoi(range_descr.first)] = range_descr.second;
518540
}
541+
542+
size_t logits_processed = 0;
543+
for (size_t i = 0; i < num_multiclass_heads; ++i) {
544+
const auto& logits_range = head_idx_to_logits_range[i];
545+
for (size_t k = logits_range.first; k < logits_range.second; ++k) {
546+
logit_idx_to_label[logits_processed++] = all_groups[i][k - logits_range.first];
547+
}
548+
}
549+
for (size_t i = 0; i < num_multilabel_heads; ++i) {
550+
logit_idx_to_label[logits_processed++] = all_groups[num_multiclass_heads + i][0];
551+
}
519552
}
520553

521554
GreedyLabelsResolver::GreedyLabelsResolver(const HierarchicalConfig& config) :

0 commit comments

Comments
 (0)