forked from mpskex/Convolutional-Pose-Machine-tf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalPCK.py
executable file
·127 lines (108 loc) · 4.19 KB
/
valPCK.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#coding: utf-8
import cv2
import time
import numpy as np
import CPM
import predict
import Global
import datagen_mpii as datagen
import matplotlib.pyplot as plt
"""
PCKh Validation
For Single Person Pose Estimation
Human Pose Estimation Project in Lab of IP
Author: Liu Fangrui aka mpsk
Beijing University of Technology
College of Computer Science & Technology
Experimental Code
!!DO NOT USE IT AS DEPLOYMENT!!
ATTENTION:
A pre-generated validation set is needed to run validation separately
"""
def visualize_accuracy(a, start=0.01, end=0.5, showlist=None, resolution=100):
    """ Plot per-joint and average accuracy against the normalized distance threshold

    For every threshold, the accuracy of a joint is the fraction of rows of
    `a` whose distance for that joint is <= the threshold. The per-threshold
    results are printed, then the selected joints and the average over all
    joints are plotted with matplotlib (blocks in plt.show()).

    Args:
        a:          2-D distance matrix (samples x joints) like err in compute_distance
        start:      smallest normalized distance threshold
        end:        largest normalized distance threshold
        showlist:   list of joint indices to be shown (None shows every joint
                    in Global.joint_list)
        resolution: number of evenly spaced thresholds between start and end
    Raises:
        ZeroDivisionError: if `a` contains no samples (rows)
    """
    # The original only bound `show_list` in the else-branch, so calling with
    # showlist=None crashed with a NameError once plotting started.
    if showlist is None:
        joints_to_plot = range(len(Global.joint_list))
    else:
        joints_to_plot = showlist

    plt.title("CPM PCKh benchmark on MPII")
    plt.xlabel("Normalized distance")
    plt.ylabel("Accuracy")

    dists = np.linspace(start, end, resolution)
    # `np.float` was only an alias of the builtin float and is gone in recent NumPy.
    re = np.zeros((dists.shape[0], a.shape[1]), float)
    av = np.zeros((dists.shape[0]), float)
    num_samples = float(a.shape[0])

    for ridx, thresh in enumerate(dists):
        print("[*]\tProcessing Result in normalized distance of %s" % (thresh,))
        # Boolean mask of every (sample, joint) entry within the threshold.
        within = a <= thresh
        for j in range(a.shape[1]):
            re[ridx, j] = np.count_nonzero(within[:, j]) / num_samples
            print("[*]\t%s  Accuracy :\t\t%s" % (Global.joint_list[j], re[ridx, j]))
        av[ridx] = np.average(re[ridx])
        print("[*]\t\tTOTAL Accuracy :\t\t%s" % (av[ridx],))

    for j in joints_to_plot:
        plt.plot(dists, re[:, j], label=Global.joint_list[j], linewidth=2.0)
    plt.plot(dists, av, label="Average", linewidth=3.0)
    plt.legend(loc='upper left')
    plt.grid()
    plt.show()
def compute_distance(model, dataset, metric='PCKh', debug=False):
    """ Compute the normalized joint distance of the model on the validation set

    Samples are fetched from `dataset.valid_set` in fixed-size batches
    (16, or 2 when debug is set), run through predict.predict, and the
    weighted distance between ground-truth and predicted joints is divided
    by the distance between two reference joints of the ground truth.

    Args:
        model:      model to load
        dataset:    dataset to generate image & ground truth
        metric:     'PCKh' normalizes by the distance between joints 9 and 8,
                    'PCK' by the distance between joints 12 and 3
                    (presumably head size / torso size in the MPII joint
                    order - confirm against Global.joint_list)
        debug:      use a batch of 2 and stop after the first batch
    Return:
        An normalized distance matrix with shape of num_of_img x joint_num
    Raises:
        ValueError: if metric is neither 'PCKh' nor 'PCK'
    """
    norm_joints = {'PCKh': (9, 8), 'PCK': (12, 3)}
    if metric not in norm_joints:
        raise ValueError("unsupported metric %r, expected 'PCKh' or 'PCK'" % (metric,))
    normJ_a, normJ_b = norm_joints[metric]

    num_samples = len(dataset.valid_set)
    # `np.float` was only an alias of the builtin float and is gone in recent NumPy.
    err = np.zeros((num_samples, len(Global.joint_list)), float)
    print("[*]\tCreated error matrix with shape of  %s" % (err.shape,))

    paral = 2 if debug else 16

    # NOTE(review): only full batches are evaluated. Rows of the trailing
    # partial batch (and, in debug mode, every row after the first batch)
    # stay zero and therefore count as perfect predictions downstream.
    # Floor division keeps the batch count an int on Python 3 as well.
    for _iter in range(num_samples // paral):
        first = _iter * paral
        im_list = []
        j_list = []
        w_list = []
        for n in range(paral):
            img, new_j, weight, _, _ = dataset.getSample(sample=dataset.valid_set[first + n])
            im_list.append(img)
            j_list.append(new_j)
            w_list.append(weight)
        w = np.array(w_list)
        j_gt = np.array(j_list)
        # estimate by network
        j_dt, _ = predict.predict(im_list, model=model, thresh=0.0)
        # Duplicate the per-joint weight for both coordinates:
        # (batch, joints) -> (batch, 2, joints) -> (batch, joints, 2)
        w = np.transpose(np.hstack((np.expand_dims(w, 1), np.expand_dims(w, 1))), axes=[0, 2, 1])
        # Iterate over the samples of this batch. The original looped over
        # len(Global.joint_list) here, which only worked when the batch size
        # happened to equal the joint count and raised IndexError in debug mode.
        for n in range(paral):
            # The last axis of the prediction is reversed before comparing
            # (presumably (y, x) vs (x, y) ordering - confirm with predict.predict).
            joint_dist = np.linalg.norm(w[n] * (j_gt[n, :, :] - j_dt[n, :, -1::-1]), axis=1)
            ref_dist = np.linalg.norm(j_gt[n, normJ_a, :] - j_gt[n, normJ_b, :], axis=0)
            err[first + n] = joint_dist / ref_dist
        print("[*]\tTemp Error is  %s" % (np.average(err[first:first + paral], axis=0),))
        if debug:
            break

    # Report the metric that was actually requested instead of always "PCKh".
    print("[*]\tAverage %s Normalised distance is  %s" % (metric, np.average(err)))
    return err
if __name__ == '__main__':
    # Script entry: build the data generator, restore the trained network,
    # measure the normalized distances and plot the accuracy curves.
    print('--Creating Dataset')
    val_data = datagen.DataGenerator(
        Global.joint_list,
        Global.IMG_ROOT,
        Global.training_txt_file,
        remove_joints=None,
        in_size=Global.INPUT_SIZE,
    )
    val_data._create_train_table()
    # Replace the validation split with the pre-generated one stored on disk.
    val_data.valid_set = np.load('Dataset-Validation-Set.npy')

    # A `cpu_only` keyword can additionally be passed to CPM.CPM for a
    # CPU-only run (left out here, as in the original configuration).
    net = CPM.CPM(pretrained_model=None, training=False)
    net.BuildModel()
    net.restore_sess('model/model.ckpt-99')

    norm_dist = compute_distance(net, val_data, metric='PCKh', debug=True)
    shown_joints = [0, 1, 4, 5, 10, 11, 14, 15]
    visualize_accuracy(norm_dist, start=0.01, end=0.5, showlist=shown_joints)