5
5
import onnxruntime as ort
6
6
7
7
class yolov5_lite ():
8
def __init__(self, model_pb_path, label_path, confThreshold=0.5, nmsThreshold=0.5):
    """Load the ONNX model and the class-name list, and store the thresholds.

    Args:
        model_pb_path: Path of the ONNX model handed to ``ort.InferenceSession``.
        label_path: Text file holding one class name per line; surrounding
            whitespace is stripped from every line.
        confThreshold: Minimum score a detection needs in order to be kept.
        nmsThreshold: Threshold forwarded to non-maximum suppression.
    """
    so = ort.SessionOptions()
    so.log_severity_level = 3  # only errors and above are logged by onnxruntime
    self.net = ort.InferenceSession(model_pb_path, so)

    # Use a context manager so the label file is closed deterministically;
    # a bare open(...).readlines() leaves the handle to the garbage collector.
    with open(label_path, 'r') as label_file:
        self.classes = [line.strip() for line in label_file]

    self.confThreshold = confThreshold
    self.nmsThreshold = nmsThreshold

    # Spatial size of the network input, read from dims 2 and 3 of the first
    # model input.
    model_input = self.net.get_inputs()[0]
    self.input_shape = (model_input.shape[2], model_input.shape[3])
26
17
27
- def resize_image (self , srcimg , keep_ratio = True ):
18
+ def letterBox (self , srcimg , keep_ratio = True ):
28
19
top , left , newh , neww = 0 , 0 , self .input_shape [0 ], self .input_shape [1 ]
29
20
if keep_ratio and srcimg .shape [0 ] != srcimg .shape [1 ]:
30
21
hw_scale = srcimg .shape [0 ] / srcimg .shape [1 ]
@@ -43,90 +34,57 @@ def resize_image(self, srcimg, keep_ratio=True):
43
34
img = cv2 .resize (srcimg , self .input_shape , interpolation = cv2 .INTER_AREA )
44
35
return img , newh , neww , top , left
45
36
46
- def _make_grid (self , nx = 20 , ny = 20 ):
47
- xv , yv = np .meshgrid (np .arange (ny ), np .arange (nx ))
48
- return np .stack ((xv , yv ), 2 ).reshape ((- 1 , 2 )).astype (np .float32 )
49
-
50
37
def postprocess(self, frame, outs, pad_hw):
    """Filter detections by score, run NMS, and draw the survivors on ``frame``.

    Args:
        frame: Image array to draw on; ``frame.shape[0]`` and ``frame.shape[1]``
            are read as its height and width.
        outs: Detections, one per row, read as
            ``[x1, y1, x2, y2, score, class_id]``. Any leading dimensions
            (e.g. a batch axis of 1) are flattened into rows.
        pad_hw: ``(newh, neww, padh, padw)`` - the resized image size and the
            top/left padding. The padding is subtracted from the coordinates
            and the result is rescaled to the size of ``frame``.

    Returns:
        ``frame`` after ``drawPred`` has been applied for every kept detection.
        ``frame`` is returned untouched when nothing passes the score threshold.
    """
    newh, neww, padh, padw = pad_hw
    ratioh = frame.shape[0] / newh
    ratiow = frame.shape[1] / neww

    dets = np.asarray(outs)
    if dets.size == 0:
        return frame
    # Accept both (N, C) and (1, N, C) outputs by collapsing leading axes.
    dets = dets.reshape(-1, dets.shape[-1])

    classIds = []
    confidences = []
    corner_boxes = []  # (x1, y1, x2, y2), used for drawing
    nms_boxes = []     # [x, y, w, h], the layout cv2.dnn.NMSBoxes expects
    for det in dets:
        # Plain Python float: the OpenCV bindings expect native floats for scores.
        score = float(det[4])
        if not score > self.confThreshold:
            continue
        x1 = int((det[0] - padw) * ratiow)
        y1 = int((det[1] - padh) * ratioh)
        x2 = int((det[2] - padw) * ratiow)
        y2 = int((det[3] - padh) * ratioh)
        classIds.append(int(det[5]))
        confidences.append(score)
        corner_boxes.append((x1, y1, x2, y2))
        nms_boxes.append([x1, y1, x2 - x1, y2 - y1])

    if not nms_boxes:
        return frame

    # Perform non maximum suppression to eliminate redundant overlapping boxes
    # with lower confidences.
    indices = cv2.dnn.NMSBoxes(nms_boxes, confidences, self.confThreshold, self.nmsThreshold)

    # Depending on the OpenCV version the result is a flat array, an N x 1
    # array, or an empty tuple; flattening handles all three.
    for ind in np.asarray(indices).reshape(-1):
        keep = int(ind)
        x1, y1, x2, y2 = corner_boxes[keep]
        frame = self.drawPred(frame, classIds[keep], confidences[keep], x1, y1, x2, y2)
    return frame
93
63
94
def drawPred(self, frame, classId, conf, x1, y1, x2, y2):
    """Draw one detection (rectangle plus a ``class:score`` label) on ``frame``.

    Args:
        frame: Image array that is drawn on in place.
        classId: Index into ``self.classes``; converted with ``int()`` so a
            float-valued id is accepted.
        conf: Score of the detection, rendered with two decimals.
        x1, y1: Top-left corner of the box, in pixels of ``frame``.
        x2, y2: Bottom-right corner of the box, in pixels of ``frame``.

    Returns:
        The same ``frame`` object, for call chaining.
    """
    # Draw a bounding box.
    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), thickness=2)

    text = '%s:%.2f' % (self.classes[int(classId)], conf)

    # Measure the label with the same font it is drawn with.
    font, scale, thickness = cv2.FONT_HERSHEY_TRIPLEX, 0.5, 1
    (_, text_h), _ = cv2.getTextSize(text, font, scale, thickness)

    # Display the label 10 px above the top of the bounding box. Clamp the
    # baseline after applying the offset so the text stays inside the image
    # when the box touches the top edge.
    text_y = max(y1 - 10, text_h)
    cv2.putText(frame, text, (x1, text_y), font, scale, (0, 255, 0), thickness=thickness)
    return frame
107
76
108
77
def detect (self , srcimg ):
109
- img , newh , neww , top , left = self .resize_image (srcimg )
78
+ img , newh , neww , top , left = self .letterBox (srcimg )
110
79
img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB )
111
80
img = img .astype (np .float32 ) / 255.0
112
81
blob = np .expand_dims (np .transpose (img , (2 , 0 , 1 )), axis = 0 )
113
82
114
83
t1 = time .time ()
115
- outs = self .net .run (None , {self .net .get_inputs ()[0 ].name : blob })[0 ]. squeeze ( axis = 0 )
84
+ outs = self .net .run (None , {self .net .get_inputs ()[0 ].name : blob })[0 ]
116
85
cost_time = time .time () - t1
117
86
print (outs .shape )
118
- row_ind = 0
119
- for i in range (self .nl ):
120
- h , w = int (self .input_shape [0 ] / self .stride [i ]), int (self .input_shape [1 ] / self .stride [i ])
121
- length = int (self .na * h * w )
122
- if self .grid [i ].shape [2 :4 ] != (h , w ):
123
- self .grid [i ] = self ._make_grid (w , h )
124
-
125
- outs [row_ind :row_ind + length , 0 :2 ] = (outs [row_ind :row_ind + length , 0 :2 ] * 2. - 0.5 + np .tile (
126
- self .grid [i ], (self .na , 1 ))) * int (self .stride [i ])
127
- outs [row_ind :row_ind + length , 2 :4 ] = (outs [row_ind :row_ind + length , 2 :4 ] * 2 ) ** 2 * np .repeat (
128
- self .anchor_grid [i ], h * w , axis = 0 )
129
- row_ind += length
87
+
130
88
srcimg = self .postprocess (srcimg , outs , (newh , neww , top , left ))
131
89
infer_time = 'Inference Time: ' + str (int (cost_time * 1000 )) + 'ms'
132
90
cv2 .putText (srcimg , infer_time , (5 , 20 ), cv2 .FONT_HERSHEY_TRIPLEX , 0.5 , (0 , 0 , 0 ), thickness = 1 )
@@ -135,8 +93,8 @@ def detect(self, srcimg):
135
93
136
94
if __name__ == '__main__' :
137
95
parser = argparse .ArgumentParser ()
138
- parser .add_argument ('--imgpath' , type = str , default = '../sample/horse .jpg' , help = "image path" )
139
- parser .add_argument ('--modelpath' , type = str , default = '../weights/v5Lite-e .onnx' , help = "onnx filepath" )
96
+ parser .add_argument ('--imgpath' , type = str , default = './000000001000 .jpg' , help = "image path" )
97
+ parser .add_argument ('--modelpath' , type = str , default = './v5lite-e_end2end .onnx' , help = "onnx filepath" )
140
98
parser .add_argument ('--classfile' , type = str , default = 'coco.names' , help = "classname filepath" )
141
99
parser .add_argument ('--confThreshold' , default = 0.5 , type = float , help = 'class confidence' )
142
100
parser .add_argument ('--nmsThreshold' , default = 0.6 , type = float , help = 'nms iou thresh' )
@@ -146,9 +104,4 @@ def detect(self, srcimg):
146
104
net = yolov5_lite (args .modelpath , args .classfile , confThreshold = args .confThreshold , nmsThreshold = args .nmsThreshold )
147
105
srcimg = net .detect (srcimg .copy ())
148
106
149
- winName = 'Deep learning object detection in onnxruntime'
150
- cv2 .namedWindow (winName , cv2 .WINDOW_NORMAL )
151
- cv2 .imshow (winName , srcimg )
152
- cv2 .waitKey (0 )
153
- # cv2.imwrite('save.jpg', srcimg )
154
- cv2 .destroyAllWindows ()
107
+ cv2 .imwrite ('save.jpg' , srcimg )
0 commit comments