-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathtest.py
131 lines (95 loc) · 3.66 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
'''
Created on 2019-06-30
@author: chenjun2hao
'''
import os
import cv2
import glob
import numpy as np
import torch
import argparse
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
from nms import get_boxes
from tools.models import ModelResNetSep2, OwnModel
import tools.net_utils as net_utils
from src.utils import strLabelConverter, alphabet
from tools.ocr_utils import ocr_image, align_ocr
from tools.data_gen import draw_box_points
def resize_image(im, max_size = 1585152, scale_up=True):
    """Resize an image so both sides are multiples of 32 and the area fits ``max_size``.

    Args:
        im: Image array; ``im.shape[0]`` is read as the height and
            ``im.shape[1]`` as the width.
        max_size: Upper bound on ``width * height`` of the resized image.
        scale_up: When true, the target size starts at three times the
            input size; otherwise it starts at the input size.

    Returns:
        A tuple ``(scaled, (resize_h, resize_w))``: the resized image and the
        height/width it was resized to (pixel sizes, not scale ratios).
    """
    stride = 32
    factor = 3 if scale_up else 1

    # Round each side down to a multiple of the stride, but never below one
    # stride: a side rounded to 0 would make cv2.resize fail.
    resize_w = max(im.shape[1] * factor // stride * stride, stride)
    resize_h = max(im.shape[0] * factor // stride * stride, stride)

    # Shrink both sides by 1.2x (re-snapped to the stride) until the area fits.
    while resize_w * resize_h > max_size:
        next_w = max(int(resize_w / 1.2 // stride) * stride, stride)
        next_h = max(int(resize_h / 1.2 // stride) * stride, stride)
        if (next_w, next_h) == (resize_w, resize_h):
            # Both sides are already at the minimum; stop instead of looping forever.
            break
        resize_w, resize_h = next_w, next_h

    scaled = cv2.resize(im, dsize=(resize_w, resize_h))
    return scaled, (resize_h, resize_w)
if __name__ == '__main__':
    # Demo script: run the detector/recognizer on every *.jpg in a folder,
    # draw the recognized text and boxes, then show and save each result.
    parser = argparse.ArgumentParser()
    parser.add_argument('-cuda', type=int, default=1)
    # Alternative checkpoint: ./weights/e2e-mlt.h5
    parser.add_argument('-model', default='./weights/FOTS_280000.h5')
    # type=float so a threshold given on the command line is numeric, not a str.
    parser.add_argument('-segm_thresh', type=float, default=0.5)
    parser.add_argument('-test_folder', default=r'./data/example_image/')
    parser.add_argument('-output', default='./data/ICDAR2015')
    args = parser.parse_args()

    font2 = ImageFont.truetype("./tools/Arial-Unicode-Regular.ttf", 18)

    net = ModelResNetSep2(attention=True, nclass=len(alphabet)+1)
    net_utils.load_net(args.model, net)
    net = net.eval()

    converter = strLabelConverter(alphabet)

    if args.cuda:
        print('Using cuda ...')
        net = net.cuda()

    # cv2.imwrite does not create missing directories, so make sure the
    # output folder exists before processing anything.
    os.makedirs(args.output, exist_ok=True)

    test_path = os.path.realpath(args.test_folder)
    imagelist = glob.glob(os.path.join(test_path, '*.jpg'))

    with torch.no_grad():
        for path in imagelist:
            im = cv2.imread(path)
            if im is None:
                # cv2.imread returns None instead of raising on unreadable files.
                print('Skipping unreadable image: {}'.format(path))
                continue

            im_resized, _ = resize_image(im, scale_up=False)

            # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
            images = np.asarray([im_resized], dtype=np.float64)
            # Scale by 1/128 and shift by -1 before feeding the network.
            images /= 128
            images -= 1
            im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=args.cuda)
            seg_pred, rboxs, angle_pred, features = net(im_data)

            # Move the first axis of the box map to the end
            # (original note: convert to h, w, c).
            rbox = rboxs[0].data.cpu()[0].numpy()
            rbox = rbox.swapaxes(0, 1)
            rbox = rbox.swapaxes(1, 2)

            angle_pred = angle_pred[0].data.cpu()[0].numpy()

            segm = seg_pred[0].data.cpu()[0].numpy()
            segm = segm.squeeze(0)

            boxes = get_boxes(segm, rbox, angle_pred, args.segm_thresh)

            # Draw text with PIL on a copy of the resized image.
            img = Image.fromarray(np.copy(im_resized))
            draw = ImageDraw.Draw(img)

            out_boxes = []
            for box in boxes:
                det_text, _, _ = align_ocr(net, converter, im_data, box, features, debug=0)
                if not det_text:
                    continue
                # The text is anchored at the box's first two values (box[0], box[1]).
                draw.text((box[0], box[1]), det_text, fill = (0,255,0),font=font2)
                out_boxes.append(box)
                print(det_text)

            # Back to a NumPy array so the box outlines can be drawn on it.
            im = np.array(img)
            for box in out_boxes:
                pts = box[0:8].reshape(4, -1)
                draw_box_points(im, pts, color=(0, 255, 0), thickness=1)

            cv2.imshow('img', im)
            basename = os.path.basename(path)
            cv2.imwrite(os.path.join(args.output, basename), im)
            cv2.waitKey(1000)