Skip to content

Commit 9092d56

Browse files
committed
add detection benchmark
1 parent d50dfca commit 9092d56

File tree

2 files changed

+183
-0
lines changed

2 files changed

+183
-0
lines changed

include/benchmark.hpp

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#pragma once
2+
3+
#include <vector>
4+
5+
#include "opencv2/opencv.hpp"
6+
7+
class DetectionQualityEvaluator {
8+
protected:
9+
size_t num_objects_;
10+
size_t num_objects_found_;
11+
size_t num_responses_;
12+
size_t num_false_alarms_;
13+
size_t num_frames_;
14+
float threshold_;
15+
16+
public:
17+
DetectionQualityEvaluator();
18+
19+
void UpdateMetrics(const std::vector<cv::Rect>& guess,
20+
const std::vector<cv::Rect>& ground_truth);
21+
void UpdateMetrics(const std::vector<cv::Rect>& guess,
22+
const std::vector<double>& scores,
23+
const std::vector<cv::Rect>& ground_truth);
24+
25+
float GetDetectionRate() const;
26+
float GetFalseAlarmRate() const;
27+
static float IntersectionOverUnion(const cv::Rect& r, const cv::Rect& p);
28+
};
29+
30+
class GroundTruthReader {
31+
protected:
32+
cv::FileStorage file_storage_;
33+
cv::FileNode file_node_;
34+
cv::FileNodeIterator file_node_iter_;
35+
bool is_opened_;
36+
37+
public:
38+
GroundTruthReader();
39+
void Open(const std::string& filename);
40+
bool Next(std::vector<cv::Rect>& rect);
41+
bool Get(std::vector<cv::Rect>& rect);
42+
bool IsOpen() const;
43+
};

src/benchmark.cpp

+140
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#include "metrics.hpp"
2+
3+
#include <algorithm>
4+
#include <iostream>
5+
#include <utility>
6+
#include <vector>
7+
8+
#include "opencv2/opencv.hpp"
9+
10+
using namespace cv;
11+
using namespace std;
12+
13+
float PrecisionRecallEvaluator::IntersectionOverUnion(const Rect& r,
14+
const Rect& p) {
15+
float intersection_area = (r & p).area();
16+
float union_area = r.area() + p.area() - intersection_area;
17+
float iou = union_area > 0 ? intersection_area / union_area : 0.0f;
18+
return iou;
19+
}
20+
21+
PrecisionRecallEvaluator::PrecisionRecallEvaluator() {
22+
num_objects_ = 0;
23+
num_objects_found_ = 0;
24+
num_responses_ = 0;
25+
num_false_alarms_ = 0;
26+
num_frames_ = 0;
27+
threshold_ = 0.5f;
28+
}
29+
30+
void PrecisionRecallEvaluator::UpdateMetrics(const vector<Rect>& guess,
31+
const vector<Rect>& ground_truth) {
32+
num_frames_++;
33+
num_objects_ += ground_truth.size();
34+
num_responses_ += guess.size();
35+
36+
vector<bool> objects_detected(ground_truth.size(), false);
37+
vector<bool> correct_detections(guess.size(), false);
38+
39+
for (size_t j = 0; j < guess.size(); ++j) {
40+
// Do not count any already matched detector alarm twice.
41+
if (correct_detections.at(j)) {
42+
continue;
43+
}
44+
const Rect& alarm = guess.at(j);
45+
for (size_t i = 0; i < ground_truth.size(); ++i) {
46+
// Do not allow several matches for one detector.
47+
if (objects_detected.at(i)) {
48+
continue;
49+
}
50+
const Rect& gt = ground_truth.at(i);
51+
if (IntersectionOverUnion(alarm, gt) >= threshold_) {
52+
objects_detected[i] = true;
53+
correct_detections[j] = true;
54+
break;
55+
}
56+
}
57+
}
58+
59+
num_objects_found_ +=
60+
std::count(objects_detected.begin(), objects_detected.end(), true);
61+
num_false_alarms_ +=
62+
std::count(correct_detections.begin(), correct_detections.end(), false);
63+
}
64+
65+
void PrecisionRecallEvaluator::UpdateMetrics(const vector<Rect>& guess,
66+
const vector<double>& scores,
67+
const vector<Rect>& ground_truth) {
68+
if (guess.size() != scores.size()) {
69+
cerr << "Check failed 'guess.size() == scores.size()' (" << guess.size()
70+
<< " vs " << scores.size() << ")." << endl;
71+
return;
72+
}
73+
// Sort detector alarm by scores.
74+
vector<size_t> idx(guess.size());
75+
iota(idx.begin(), idx.end(), 0);
76+
std::sort(idx.begin(), idx.end(),
77+
[&](size_t i, size_t j) { return scores.at(i) > scores.at(j); });
78+
vector<Rect> guess_sorted(guess.size());
79+
std::transform(idx.begin(), idx.end(), guess_sorted.begin(),
80+
[&](size_t i) { return guess.at(i); });
81+
// Evaluate metrics.
82+
UpdateMetrics(guess_sorted, ground_truth);
83+
}
84+
85+
float PrecisionRecallEvaluator::GetDetectionRate() const {
86+
return num_objects_ == 0
87+
? 0.0f
88+
: num_objects_found_ / static_cast<float>(num_objects_);
89+
}
90+
91+
float PrecisionRecallEvaluator::GetFalseAlarmRate() const {
92+
return num_responses_ == 0
93+
? 0.0f
94+
: num_false_alarms_ / static_cast<float>(num_responses_);
95+
}
96+
97+
GroundTruthReader::GroundTruthReader() {
98+
is_opened_ = false;
99+
}
100+
101+
void GroundTruthReader::Open(const string& filename) {
102+
try {
103+
is_opened_ = file_storage_.open(filename, FileStorage::READ);
104+
} catch (exception&) {
105+
cerr << "Failed to read ground truth file." << endl;
106+
is_opened_ = false;
107+
}
108+
if (is_opened_) {
109+
file_node_ = file_storage_["objects"];
110+
CV_Assert(file_node_.isSeq());
111+
is_opened_ = !(file_node_.empty() || file_node_.isNone());
112+
if (is_opened_) {
113+
file_node_iter_ = file_node_.begin();
114+
}
115+
}
116+
}
117+
118+
bool GroundTruthReader::Next(vector<Rect>& rects) {
119+
if (file_node_iter_ != file_node_.end()) {
120+
++file_node_iter_;
121+
}
122+
return Get(rects);
123+
}
124+
125+
bool GroundTruthReader::Get(vector<Rect>& rects) {
126+
rects.clear();
127+
if (!is_opened_ || file_node_iter_ == file_node_.end()) {
128+
return false;
129+
} else {
130+
CV_Assert((*file_node_iter_).isSeq());
131+
for (auto i : *file_node_iter_) {
132+
Rect rect;
133+
i >> rect;
134+
rects.emplace_back(rect);
135+
}
136+
return true;
137+
}
138+
}
139+
140+
bool GroundTruthReader::IsOpen() const { return is_opened_; }

0 commit comments

Comments
 (0)