Skip to content

Commit 1fcfd86

Browse files
committed
Fix saliency maps order for h-cls in cpp
1 parent c3998dc commit 1fcfd86

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

model_api/cpp/models/include/models/classification_model.h

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct HierarchicalConfig {
3737
std::vector<std::pair<std::string, std::string>> label_tree_edges;
3838
std::vector<std::vector<std::string>> all_groups;
3939
std::map<size_t, std::pair<size_t,size_t>> head_idx_to_logits_range;
40+
std::map<size_t, std::string> logit_idx_to_label;
4041
size_t num_multiclass_heads;
4142
size_t num_multilabel_heads;
4243
size_t num_single_label_classes;
@@ -90,4 +91,6 @@ class ClassificationModel : public ImageModel {
9091
std::unique_ptr<ResultBase> get_multilabel_predictions(InferenceResult& infResult, bool add_raw_scores);
9192
std::unique_ptr<ResultBase> get_multiclass_predictions(InferenceResult& infResult, bool add_raw_scores);
9293
std::unique_ptr<ResultBase> get_hierarchical_predictions(InferenceResult& infResult, bool add_raw_scores);
94+
ov::Tensor reorder_saliency_maps(const ov::Tensor&);
95+
9396
};

model_api/cpp/models/src/classification_model.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ std::unique_ptr<ResultBase> ClassificationModel::postprocess(InferenceResult& in
263263
auto saliency_map_iter = infResult.outputsData.find(saliency_map_name);
264264
if (saliency_map_iter != infResult.outputsData.end()) {
265265
cls_res->saliency_map = std::move(saliency_map_iter->second);
266+
cls_res->saliency_map = reorder_saliency_maps(cls_res->saliency_map);
266267
}
267268
auto feature_vector_iter = infResult.outputsData.find(feature_vector_name);
268269
if (feature_vector_iter != infResult.outputsData.end()) {
@@ -358,6 +359,27 @@ std::unique_ptr<ResultBase> ClassificationModel::get_hierarchical_predictions(In
358359
return retVal;
359360
}
360361

362+
ov::Tensor ClassificationModel::reorder_saliency_maps(const ov::Tensor& source_maps) {
363+
if (!hierarchical || source_maps.get_shape().size() == 1) {
364+
return source_maps;
365+
}
366+
367+
auto reordered_maps = ov::Tensor(source_maps.get_element_type(), source_maps.get_shape());
368+
const std::uint8_t* source_maps_ptr = static_cast<std::uint8_t*>(source_maps.data());
369+
std::uint8_t* reordered_maps_ptr = static_cast<std::uint8_t*>(reordered_maps.data());
370+
371+
size_t shape_offset = (source_maps.get_shape().size() == 4) ? 1 : 0;
372+
size_t map_byte_size = source_maps.get_element_type().size() *
373+
source_maps.get_shape()[shape_offset + 1] * source_maps.get_shape()[shape_offset + 2];
374+
375+
for (size_t i = 0; i < source_maps.get_shape()[shape_offset]; ++i) {
376+
size_t new_index = hierarchical_info.label_to_idx[hierarchical_info.logit_idx_to_label[i]];
377+
std::copy_n(source_maps_ptr + i*map_byte_size, map_byte_size, reordered_maps_ptr + new_index * map_byte_size);
378+
}
379+
380+
return reordered_maps;
381+
}
382+
361383
std::unique_ptr<ResultBase> ClassificationModel::get_multiclass_predictions(InferenceResult& infResult, bool add_raw_scores) {
362384
const ov::Tensor& indicesTensor = infResult.outputsData.find(indices_name)->second;
363385
const int* indicesPtr = indicesTensor.data<int>();
@@ -490,6 +512,17 @@ HierarchicalConfig::HierarchicalConfig(const std::string& json_repr) {
490512
for (const auto& range_descr : tmp_head_idx_to_logits_range) {
491513
head_idx_to_logits_range[stoi(range_descr.first)] = range_descr.second;
492514
}
515+
516+
size_t logits_processed = 0;
517+
for (size_t i = 0; i < num_multiclass_heads; ++i) {
518+
const auto& logits_range = head_idx_to_logits_range[i];
519+
for (size_t k = logits_range.first; k < logits_range.second; ++k) {
520+
logit_idx_to_label[logits_processed++] = all_groups[i][k - logits_range.first];
521+
}
522+
}
523+
for (size_t i = 0; i < num_multilabel_heads; ++i) {
524+
logit_idx_to_label[logits_processed++] = all_groups[num_multiclass_heads + i][0];
525+
}
493526
}
494527

495528
GreedyLabelsResolver::GreedyLabelsResolver(const HierarchicalConfig& config) :

0 commit comments

Comments
 (0)