Skip to content

Commit 5b56725

Browse files
authored
Update check.py
1 parent ee47923 commit 5b56725

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

scripts/check.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def non_max_suppression_end2end(prediction, conf_thres=0.25, iou_thres=0.45, cla
4949

5050
def non_max_suppression_mnne(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, nc=None):
5151
output = []
52-
print(prediction.shape)
5352

5453
xc = prediction[:, 4] > conf_thres # candidates
5554
output = prediction[xc]
@@ -58,9 +57,6 @@ def non_max_suppression_mnne(prediction, conf_thres=0.25, iou_thres=0.45, classe
5857
print(output.shape)
5958

6059
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
61-
print(output[i].shape)
62-
63-
print(output[i])
6460

6561
return output[i]
6662

@@ -87,7 +83,6 @@ def non_max_suppression_mnnd(prediction, conf_thres=0.25, iou_thres=0.45, classe
8783
boxes, scores = x[:, :4] +c , x[:, 4] # boxes (offset by class), scores
8884
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
8985
output[xi] = x[i].view(-1, 6)
90-
print(output[0])
9186

9287
return output[0]
9388

@@ -118,7 +113,15 @@ def process(weight_path, img_path):
118113
out = sess.run(['outputs'], {'images': image.numpy()})[0]
119114
out = torch.from_numpy(out)
120115

121-
output = non_max_suppression_end2end(out, 0.50, 0.50, nc=1)
116+
# 如果使用的是end2end的导出方式,则使用以下后处理
117+
# output = non_max_suppression_end2end(out, 0.50, 0.50, nc=80)
118+
119+
# 如果使用的是mnnd的导出方式,则使用以下后处理
120+
output = non_max_suppression_mnnd(out, 0.50, 0.50, nc=80)
121+
122+
# 如果使用的是mnne的导出方式,则使用以下后处理
123+
# output = non_max_suppression_mnne(out, 0.50, 0.50, nc=80)
124+
122125
nimg = image[0].permute(1, 2, 0) * 255
123126
nimg = nimg.cpu().numpy().astype(np.uint8)
124127
nimg = cv2.cvtColor(nimg, cv2.COLOR_BGR2RGB)

0 commit comments

Comments
 (0)