cnnmatching.py
import argparse
import cv2
import numpy as np
import imageio
import plotmatch
from lib.cnn_feature import cnn_feature_extract
import matplotlib.pyplot as plt
import time
from skimage import measure
from skimage import transform
# Timing
start = time.perf_counter()
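# Residual threshold (in pixels) used by the RANSAC affine verification below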
_RESIDUAL_THRESHOLD = 30
# Image pair to be matched; the unused Google Earth 2009 image is kept as an alternative.
#imgfile1 = 'df-ms-data/1/df-googleearth-1k-20091227.jpg'
imgfile1 = 'df-ms-data/1/df-uav-sar-500.jpg'
imgfile2 = 'df-ms-data/1/df-googleearth-1k-20181029.jpg'
start = time.perf_counter()
# Read both images
image1 = imageio.imread(imgfile1)
image2 = imageio.imread(imgfile2)
print('read image time is %6.3f' % (time.perf_counter() - start))
start0 = time.perf_counter()
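# Extract CNN keypoints, scores and descriptors for both images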
kps_left, sco_left, des_left = cnn_feature_extract(image1, nfeatures=-1)
kps_right, sco_right, des_right = cnn_feature_extract(image2, nfeatures=-1)
print('Feature_extract time is %6.3f, left: %6d, right: %6d' % ((time.perf_counter() - start0), len(kps_left), len(kps_right)))
start = time.perf_counter()
# FLANN feature matching
FLANN_INDEX_KDTREE = 1
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
search_params = dict(checks=40)
flann = cv2.FlannBasedMatcher(index_params, search_params)
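# k-nearest-neighbour matching (k=2): each left descriptor gets its two best right candidates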
matches = flann.knnMatch(des_left, des_right, k=2)
goodMatch = []
locations_1_to_use = []
locations_2_to_use = []
# Select good match pairs
disdif_avg = 0
# Average gap between the best and second-best match distances
for m, n in matches:
    disdif_avg += n.distance - m.distance
disdif_avg = disdif_avg / len(matches)

for m, n in matches:
    # Adaptive threshold: keep a match only if the second-best distance exceeds the best by more than the average gap
    if n.distance > m.distance + disdif_avg:
        goodMatch.append(m)
        p2 = cv2.KeyPoint(kps_right[m.trainIdx][0], kps_right[m.trainIdx][1], 1)
        p1 = cv2.KeyPoint(kps_left[m.queryIdx][0], kps_left[m.queryIdx][1], 1)
        locations_1_to_use.append([p1.pt[0], p1.pt[1]])
        locations_2_to_use.append([p2.pt[0], p2.pt[1]])
#goodMatch = sorted(goodMatch, key=lambda x: x.distance)
print('match num is %d' % len(goodMatch))
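# Convert the point lists to arrays for RANSAC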
locations_1_to_use = np.array(locations_1_to_use)
locations_2_to_use = np.array(locations_2_to_use)
# Perform geometric verification using RANSAC.
_, inliers = measure.ransac((locations_1_to_use, locations_2_to_use),
                            transform.AffineTransform,
                            min_samples=3,
                            residual_threshold=_RESIDUAL_THRESHOLD,
                            max_trials=1000)
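# Note: the fitted AffineTransform itself is discarded ('_'); bind it instead
# (e.g. model, inliers = measure.ransac(...)) if the estimated transform parameters are needed.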
print('Found %d inliers' % sum(inliers))
inlier_idxs = np.nonzero(inliers)[0]
# Final matching result (indices of the RANSAC inliers)
matches = np.column_stack((inlier_idxs, inlier_idxs))
print('whole time is %6.3f' % (time.perf_counter() - start0))
# Visualize the correspondences.
# 1. Draw the matching lines
plt.rcParams['savefig.dpi'] = 100  # saved-figure DPI
plt.rcParams['figure.dpi'] = 100  # display DPI
plt.rcParams['figure.figsize'] = (4.0, 3.0)  # figure size in inches
_, ax = plt.subplots()
plotmatch.plot_matches(
    ax,
    image1,
    image2,
    locations_1_to_use,
    locations_2_to_use,
    np.column_stack((inlier_idxs, inlier_idxs)),
    plot_matche_points=False,
    matchline=True,
    matchlinewidth=0.3)
ax.axis('off')
ax.set_title('')
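# Optionally save the visualization as well (the output filename here is only an example):
# plt.savefig('cnn_matches.png', bbox_inches='tight', dpi=300)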
plt.show()