
Commit 0654107

Committed Dec 17, 2019
Add evaluation code
1 parent 8e5914f commit 0654107

File tree: 4 files changed, +184 -12 lines changed


test.py

+9 -10
@@ -8,7 +8,7 @@
 
 from anchor import generate_default_boxes
 from box_utils import decode, compute_nms
-from data import create_batch_generator
+from voc_data import create_batch_generator
 from image_utils import ImageVisualizer
 from losses import create_losses
 from network import create_ssd
@@ -118,13 +118,12 @@ def predict(imgs, default_boxes):
         visualizer.save_image(
             original_image, boxes, classes, '{}.jpg'.format(filename))
 
-        log_file = os.path.join('outputs/detects', '{}.txt'.format(filename))
+        log_file = os.path.join('outputs/detects', '{}.txt')
 
-        with open(log_file, 'w') as f:
-            log = []
-            for cls, box, score in zip(classes, boxes, scores):
-                cls_name = info['idx_to_name'][cls - 1]
-                log.append(
-                    ','.join([cls_name, *[str(c) for c in box], str(score)]))
-            log = '\n'.join(log)
-            f.write(log)
+        for cls, box, score in zip(classes, boxes, scores):
+            cls_name = info['idx_to_name'][cls - 1]
+            with open(log_file.format(cls_name), 'a') as f:
+                f.write('{} {} {} {} {} {}\n'.format(
+                    filename,
+                    score,
+                    *[coord for coord in box]))
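With this change, test.py stops writing one comma-separated log per image and instead appends one line per detection to a per-class file under outputs/detects ('{}.txt' formatted with the class name). Each line holds the image id, the confidence score, and the four box coordinates separated by spaces, which is the layout voc_eval.py parses later. A minimal sketch of reading one of these files back, assuming a hypothetical outputs/detects/dog.txt already produced by test.py:

# Parse a per-class detection file the same way voc_eval.py does.
# The path 'outputs/detects/dog.txt' is illustrative, not taken from the commit.
with open('outputs/detects/dog.txt', 'r') as f:
    lines = [x.strip().split(' ') for x in f.readlines()]

image_ids = [x[0] for x in lines]                   # image filename / id
scores = [float(x[1]) for x in lines]               # detection confidences
boxes = [[float(z) for z in x[2:]] for x in lines]  # treated as xmin, ymin, xmax, ymax by voc_eval.py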

train.py

+2 -2
@@ -6,7 +6,7 @@
 import yaml
 
 from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
-from data import create_batch_generator
+from voc_data import create_batch_generator
 from anchor import generate_default_boxes
 from network import create_ssd
 from losses import create_losses
@@ -110,7 +110,7 @@ def train_step(imgs, gt_confs, gt_locs, ssd, criterion, optimizer):
             avg_loss = (avg_loss * i + loss.numpy()) / (i + 1)
             avg_conf_loss = (avg_conf_loss * i + conf_loss.numpy()) / (i + 1)
             avg_loc_loss = (avg_loc_loss * i + loc_loss.numpy()) / (i + 1)
-            if (i + 1) % 2 == 0:
+            if (i + 1) % 50 == 0:
                 print('Epoch: {} Batch {} Time: {:.2}s | Loss: {:.4f} Conf: {:.4f} Loc: {:.4f}'.format(
                     epoch + 1, i + 1, time.time() - start, avg_loss, avg_conf_loss, avg_loc_loss))
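The running averages above use the incremental update avg = (avg * i + x) / (i + 1), which after batch i equals the plain mean of the first i + 1 per-batch values; the commit only changes how often the result is printed (every 50 batches instead of every 2). A small self-contained sketch checking that equivalence on made-up loss values:

import numpy as np

# Illustrative per-batch losses, not real training output.
losses = [2.5, 2.1, 1.9, 1.8, 1.7]

avg_loss = 0.0
for i, loss in enumerate(losses):
    # Same incremental update as in the training loop above.
    avg_loss = (avg_loss * i + loss) / (i + 1)

print(avg_loss, np.mean(losses))  # both are (up to float rounding) 2.0
assert abs(avg_loss - np.mean(losses)) < 1e-9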

data.py → voc_data.py

File renamed without changes.

voc_eval.py

+173
@@ -0,0 +1,173 @@
+import os
+import numpy as np
+import xml.etree.ElementTree as ET
+import argparse
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--data-dir', default='../dataset')
+parser.add_argument('--data-year', default='2007')
+parser.add_argument('--detect-dir', default='./outputs/detects')
+parser.add_argument('--use-07-metric', type=bool, default=False)
+args = parser.parse_args()
+
+
+def get_annotation(anno_file):
+    tree = ET.parse(anno_file)
+    objects = []
+    for obj in tree.findall('object'):
+        obj_struct = {}
+        obj_struct['name'] = obj.find('name').text
+        obj_struct['pose'] = obj.find('pose').text
+        obj_struct['truncated'] = int(obj.find('truncated').text)
+        obj_struct['difficult'] = int(obj.find('difficult').text)
+        bbox = obj.find('bndbox')
+        obj_struct['bbox'] = [int(bbox.find('xmin').text),
+                              int(bbox.find('ymin').text),
+                              int(bbox.find('xmax').text),
+                              int(bbox.find('ymax').text)]
+        objects.append(obj_struct)
+
+    return objects
+
+
+def compute_ap(rec, prec, use_07_metric=False):
+    if use_07_metric:
+        ap = 0.0
+        for t in np.arange(0.0, 1.1, 0.1):
+            if np.sum(rec >= t) == 0:
+                p = 0
+            else:
+                p = np.max(prec[rec >= t])
+            ap = ap + p / 11.0
+    else:
+        mrec = np.concatenate(([0.0], rec, [1.0]))
+        mprec = np.concatenate(([0.0], prec, [0.0]))
+
+        for i in range(mprec.size - 1, 0, -1):
+            mprec[i - 1] = np.maximum(mprec[i - 1], mprec[i])
+
+        i = np.where(mrec[1:] != mrec[:-1])[0]
+
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mprec[i + 1])
+
+    return ap
+
+
+def voc_eval(det_path, anno_path, cls_name, iou_thresh=0.5, use_07_metric=False):
+    det_file = det_path.format(cls_name)
+    with open(det_file, 'r') as f:
+        lines = f.readlines()
+
+    lines = [x.strip().split(' ') for x in lines]
+    image_ids = [x[0] for x in lines]
+    confs = np.array([float(x[1]) for x in lines])
+    boxes = np.array([[float(z) for z in x[2:]] for x in lines])
+
+    gts = {}
+    cls_gts = {}
+    npos = 0
+    for image_id in image_ids:
+        gts[image_id] = get_annotation(anno_path.format(image_id))
+        R = [obj for obj in gts[image_id] if obj['name'] == cls_name]
+        gt_boxes = np.array([x['bbox'] for x in R])
+        difficult = np.array([x['difficult'] for x in R]).astype(bool)
+        det = [False] * len(R)
+        npos = npos + sum(~difficult)
+        cls_gts[image_id] = {
+            'gt_boxes': gt_boxes,
+            'difficult': difficult,
+            'det': det
+        }
+
+    sorted_ids = np.argsort(-confs)
+    sorted_scores = np.sort(-confs)
+    boxes = boxes[sorted_ids, :]
+    image_ids = [image_ids[x] for x in sorted_ids]
+
+    nd = len(image_ids)
+    tp = np.zeros(nd)
+    fp = np.zeros(nd)
+    for d in range(nd):
+        R = cls_gts[image_ids[d]]
+        box = boxes[d, :].astype(float)
+        iou_max = -np.inf
+        gt_box = R['gt_boxes'].astype(float)
+
+        if gt_box.size > 0:
+            ixmin = np.maximum(gt_box[:, 0], box[0])
+            ixmax = np.minimum(gt_box[:, 2], box[2])
+            iymin = np.maximum(gt_box[:, 1], box[1])
+            iymax = np.minimum(gt_box[:, 3], box[3])
+            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
+            ih = np.maximum(iymax - iymin + 1.0, 0.0)
+            inters = iw * ih
+
+            uni = ((box[2] - box[0] + 1.0) * (box[3] - box[1] + 1.0) +
+                   (gt_box[:, 2] - gt_box[:, 0] + 1.0) *
+                   (gt_box[:, 3] - gt_box[:, 1] + 1.0) - inters)
+
+            ious = inters / uni
+            iou_max = np.max(ious)
+            jmax = np.argmax(ious)
+
+        if iou_max > iou_thresh:
+            if not R['difficult'][jmax]:
+                if not R['det'][jmax]:
+                    tp[d] = 1.0
+                    R['det'][jmax] = 1
+                else:
+                    fp[d] = 1.0
+        else:
+            fp[d] = 1.0
+
+    fp = np.cumsum(fp)
+    tp = np.cumsum(tp)
+    recall = tp / float(npos)
+    precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+
+    ap = compute_ap(recall, precision, use_07_metric)
+
+    return recall, precision, ap
+
+
+if __name__ == '__main__':
+    aps = {
+        'aeroplane': 0.0,
+        'bicycle': 0.0,
+        'bird': 0.0,
+        'boat': 0.0,
+        'bottle': 0.0,
+        'bus': 0.0,
+        'car': 0.0,
+        'cat': 0.0,
+        'chair': 0.0,
+        'cow': 0.0,
+        'diningtable': 0.0,
+        'dog': 0.0,
+        'horse': 0.0,
+        'motorbike': 0.0,
+        'person': 0.0,
+        'pottedplant': 0.0,
+        'sheep': 0.0,
+        'sofa': 0.0,
+        'train': 0.0,
+        'tvmonitor': 0.0,
+        'mAP': []
+    }
+    for cls_name in aps.keys():
+        det_path = os.path.join(args.detect_dir, '{}.txt')
+        anno_path = os.path.join(
+            args.data_dir, 'VOC{}'.format(args.data_year), 'Annotations', '{}.xml')
+        if os.path.exists(det_path.format(cls_name)):
+            recall, precision, ap = voc_eval(
+                det_path,
+                anno_path,
+                cls_name,
+                use_07_metric=args.use_07_metric)
+            aps[cls_name] = ap
+            aps['mAP'].append(ap)
+
+    aps['mAP'] = np.mean(aps['mAP'])
+    for key, value in aps.items():
+        print('{}: {}'.format(key, value))
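compute_ap implements both VOC conventions: with --use-07-metric it returns the 11-point interpolated AP (precision sampled at recalls 0.0, 0.1, ..., 1.0 and averaged), otherwise it integrates the area under the interpolated precision-recall curve. A short sketch on a toy curve, followed by how the script itself might be invoked with the flags defined above; the numbers are made up, and the import assumes voc_eval.py is on the path (its argument parser runs at import time, so run the sketch without extra command-line flags):

import numpy as np

from voc_eval import compute_ap

# Toy monotone recall values and corresponding precisions (illustrative only).
recall = np.array([0.1, 0.4, 0.7, 1.0])
precision = np.array([1.0, 0.8, 0.6, 0.5])

# Default behaviour: area under the interpolated precision-recall curve.
print(compute_ap(recall, precision, use_07_metric=False))
# VOC 2007 convention: 11-point interpolated AP.
print(compute_ap(recall, precision, use_07_metric=True))

# Typical invocation once test.py has filled outputs/detects with per-class files:
#   python voc_eval.py --data-dir ../dataset --data-year 2007 --detect-dir ./outputs/detects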
