Skip to content

Commit 98bdd41

Browse files
authored
Update v5lite.py
1 parent 5d64b99 commit 98bdd41

File tree

1 file changed

+26
-73
lines changed

1 file changed

+26
-73
lines changed

python_demo/onnxruntime/v5lite.py

+26-73
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,17 @@
55
import onnxruntime as ort
66

77
class yolov5_lite():
8-
def __init__(self, model_pb_path, label_path, confThreshold=0.5, nmsThreshold=0.5, objThreshold=0.5):
8+
def __init__(self, model_pb_path, label_path, confThreshold=0.5, nmsThreshold=0.5):
99
so = ort.SessionOptions()
1010
so.log_severity_level = 3
1111
self.net = ort.InferenceSession(model_pb_path, so)
1212
self.classes = list(map(lambda x: x.strip(), open(label_path, 'r').readlines()))
13-
self.num_classes = len(self.classes)
14-
anchors = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
15-
self.nl = len(anchors)
16-
self.na = len(anchors[0]) // 2
17-
self.no = self.num_classes + 5
18-
self.grid = [np.zeros(1)] * self.nl
19-
self.stride = np.array([8., 16., 32.])
20-
self.anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(self.nl, -1, 2)
2113

2214
self.confThreshold = confThreshold
2315
self.nmsThreshold = nmsThreshold
24-
self.objThreshold = objThreshold
2516
self.input_shape = (self.net.get_inputs()[0].shape[2], self.net.get_inputs()[0].shape[3])
2617

27-
def resize_image(self, srcimg, keep_ratio=True):
18+
def letterBox(self, srcimg, keep_ratio=True):
2819
top, left, newh, neww = 0, 0, self.input_shape[0], self.input_shape[1]
2920
if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
3021
hw_scale = srcimg.shape[0] / srcimg.shape[1]
@@ -43,90 +34,57 @@ def resize_image(self, srcimg, keep_ratio=True):
4334
img = cv2.resize(srcimg, self.input_shape, interpolation=cv2.INTER_AREA)
4435
return img, newh, neww, top, left
4536

46-
def _make_grid(self, nx=20, ny=20):
47-
xv, yv = np.meshgrid(np.arange(ny), np.arange(nx))
48-
return np.stack((xv, yv), 2).reshape((-1, 2)).astype(np.float32)
49-
5037
def postprocess(self, frame, outs, pad_hw):
5138
newh, neww, padh, padw = pad_hw
5239
frameHeight = frame.shape[0]
5340
frameWidth = frame.shape[1]
5441
ratioh, ratiow = frameHeight / newh, frameWidth / neww
55-
# Scan through all the bounding boxes output from the network and keep only the
56-
# ones with high confidence scores. Assign the box's class label as the class with the highest score.
5742
classIds = []
5843
confidences = []
59-
box_index = []
6044
boxes = []
61-
outs = outs[outs[:, 4] > self.objThreshold]
6245
for detection in outs:
63-
scores = detection[5:]
64-
classId = np.argmax(scores)
65-
confidence = scores[classId]
66-
if confidence > self.confThreshold: # and detection[4] > self.objThreshold:
67-
center_x = int((detection[0] - padw) * ratiow)
68-
center_y = int((detection[1] - padh) * ratioh)
69-
width = int(detection[2] * ratiow)
70-
height = int(detection[3] * ratioh)
71-
left = int(center_x - width / 2)
72-
top = int(center_y - height / 2)
46+
scores, classId = detection[4], detection[5]
47+
if scores > self.confThreshold: # and detection[4] > self.objThreshold:
48+
x1 = int((detection[0] - padw) * ratiow)
49+
y1 = int((detection[1] - padh) * ratioh)
50+
x2 = int((detection[2] - padw) * ratiow)
51+
y2 = int((detection[3] - padh) * ratioh)
7352
classIds.append(classId)
74-
confidences.append(float(confidence))
75-
boxes.append([left, top, width, height])
53+
confidences.append(scores)
54+
boxes.append([x1, y1, x2, y2])
7655

77-
# Perform non maximum suppression to eliminate redundant overlapping boxes with
78-
# lower confidences.
79-
print(boxes)
56+
# # Perform non maximum suppression to eliminate redundant overlapping boxes with
57+
# # lower confidences.
8058
indices = cv2.dnn.NMSBoxes(boxes, confidences, self.confThreshold, self.nmsThreshold)
8159

82-
for i in indices:
83-
box_index.append(i[0])
84-
85-
for i in box_index:
86-
box = boxes[i]
87-
left = box[0]
88-
top = box[1]
89-
width = box[2]
90-
height = box[3]
91-
frame = self.drawPred(frame, classIds[i], confidences[i], left, top, left + width, top + height)
60+
for ind in indices:
61+
frame = self.drawPred(frame, classIds[ind], confidences[ind], boxes[ind][0], boxes[ind][1], boxes[ind][2], boxes[ind][3])
9262
return frame
9363

94-
def drawPred(self, frame, classId, conf, left, top, right, bottom):
64+
def drawPred(self, frame, classId, conf, x1, y1, x2, y2):
9565
# Draw a bounding box.
96-
cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), thickness=2)
66+
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), thickness=2)
9767

9868
label = '%.2f' % conf
99-
label = '%s:%s' % (self.classes[classId], label)
69+
text = '%s:%s' % (self.classes[int(classId)], label)
10070

10171
# Display the label at the top of the bounding box
102-
labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
103-
top = max(top, labelSize[1])
104-
# cv.rectangle(frame, (left, top - round(1.5 * labelSize[1])), (left + round(1.5 * labelSize[0]), top + baseLine), (255,255,255), cv.FILLED)
105-
cv2.putText(frame, label, (left, top - 10), cv2.FONT_HERSHEY_TRIPLEX, 0.5, (0, 255, 0), thickness=1)
72+
labelSize, baseLine = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
73+
y1 = max(y1, labelSize[1])
74+
cv2.putText(frame, text, (x1, y1 - 10), cv2.FONT_HERSHEY_TRIPLEX, 0.5, (0, 255, 0), thickness=1)
10675
return frame
10776

10877
def detect(self, srcimg):
109-
img, newh, neww, top, left = self.resize_image(srcimg)
78+
img, newh, neww, top, left = self.letterBox(srcimg)
11079
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
11180
img = img.astype(np.float32) / 255.0
11281
blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
11382

11483
t1 = time.time()
115-
outs = self.net.run(None, {self.net.get_inputs()[0].name: blob})[0].squeeze(axis=0)
84+
outs = self.net.run(None, {self.net.get_inputs()[0].name: blob})[0]
11685
cost_time = time.time() - t1
11786
print(outs.shape)
118-
row_ind = 0
119-
for i in range(self.nl):
120-
h, w = int(self.input_shape[0] / self.stride[i]), int(self.input_shape[1] / self.stride[i])
121-
length = int(self.na * h * w)
122-
if self.grid[i].shape[2:4] != (h, w):
123-
self.grid[i] = self._make_grid(w, h)
124-
125-
outs[row_ind:row_ind + length, 0:2] = (outs[row_ind:row_ind + length, 0:2] * 2. - 0.5 + np.tile(
126-
self.grid[i], (self.na, 1))) * int(self.stride[i])
127-
outs[row_ind:row_ind + length, 2:4] = (outs[row_ind:row_ind + length, 2:4] * 2) ** 2 * np.repeat(
128-
self.anchor_grid[i], h * w, axis=0)
129-
row_ind += length
87+
13088
srcimg = self.postprocess(srcimg, outs, (newh, neww, top, left))
13189
infer_time = 'Inference Time: ' + str(int(cost_time * 1000)) + 'ms'
13290
cv2.putText(srcimg, infer_time, (5, 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, (0, 0, 0), thickness=1)
@@ -135,8 +93,8 @@ def detect(self, srcimg):
13593

13694
if __name__ == '__main__':
13795
parser = argparse.ArgumentParser()
138-
parser.add_argument('--imgpath', type=str, default='../sample/horse.jpg', help="image path")
139-
parser.add_argument('--modelpath', type=str, default='../weights/v5Lite-e.onnx', help="onnx filepath")
96+
parser.add_argument('--imgpath', type=str, default='./000000001000.jpg', help="image path")
97+
parser.add_argument('--modelpath', type=str, default='./v5lite-e_end2end.onnx', help="onnx filepath")
14098
parser.add_argument('--classfile', type=str, default='coco.names', help="classname filepath")
14199
parser.add_argument('--confThreshold', default=0.5, type=float, help='class confidence')
142100
parser.add_argument('--nmsThreshold', default=0.6, type=float, help='nms iou thresh')
@@ -146,9 +104,4 @@ def detect(self, srcimg):
146104
net = yolov5_lite(args.modelpath, args.classfile, confThreshold=args.confThreshold, nmsThreshold=args.nmsThreshold)
147105
srcimg = net.detect(srcimg.copy())
148106

149-
winName = 'Deep learning object detection in onnxruntime'
150-
cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
151-
cv2.imshow(winName, srcimg)
152-
cv2.waitKey(0)
153-
# cv2.imwrite('save.jpg', srcimg )
154-
cv2.destroyAllWindows()
107+
cv2.imwrite('save.jpg', srcimg )

0 commit comments

Comments
 (0)