-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmindPred.py
101 lines (77 loc) · 2.92 KB
/
mindPred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import numpy as np
import operator
import torch
import torch.nn as nn
# Human-readable labels, looked up by the arg-max class index the model
# produces (index 0 -> "go", index 1 -> "none").
ACTION = ["go", "none"]
class net(nn.Module):
    """Small 1-D convolutional classifier over 8-channel sequences.

    The layer attribute names below are the ``state_dict`` keys used by saved
    checkpoints, so they must not be renamed.

    Shape arithmetic: Conv(k=2) -> MaxPool(k=2, s=1) -> Conv(k=3) ->
    MaxPool(k=2, s=2) turns a length-L input into floor((L - 6) / 2) + 1 time
    steps over 6 channels. ``o2s`` expects exactly ``FLAT_FEATURES`` = 6 * 28
    = 168 features per sample, which holds for L == 60 or L == 61.
    """

    # Channels out of conv1_2 times the time steps left after maxpl_2; must
    # match the in_features of the o2s layer stored in the checkpoint.
    FLAT_FEATURES = 6 * 28

    def __init__(self, output_size=2):
        """Build the layers.

        Args:
            output_size: Number of classes ``forward`` scores per sample.
        """
        super().__init__()
        self.conv1_1 = nn.Conv1d(8, 6, 2, stride=1)
        self.maxpl_1 = nn.MaxPool1d(2, stride=1)
        self.conv1_2 = nn.Conv1d(6, 6, 3, stride=1)
        self.maxpl_2 = nn.MaxPool1d(2, stride=2)
        self.o2s = nn.Linear(self.FLAT_FEATURES, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_seq, hidden=None):
        """Return per-class log-probabilities for a batch of sequences.

        Args:
            input_seq: Tensor of shape (batch, 8, L) with L == 60 or 61 (see
                the class docstring for why).
            hidden: Not used by any layer; returned unchanged.

        Returns:
            tuple: ``(log_probs, hidden)`` where ``log_probs`` has shape
            (batch, output_size).
        """
        output = self.maxpl_1(self.conv1_1(input_seq))
        output = self.maxpl_2(self.conv1_2(output))
        # Flatten per sample. The old reshape(-1, 168) silently regrouped values
        # across samples when L was wrong; now a bad length fails in o2s. An
        # unbatched (channels, length) input still becomes a batch of one.
        batch = output.shape[0] if output.dim() == 3 else 1
        output = output.reshape(batch, -1)
        output = self.softmax(self.o2s(output))
        return output, hidden

    def initHidden(self):
        """Return a zero tensor of shape (4, 16, 9).

        Kept for API compatibility: it is a leftover from an earlier
        ``nn.RNN(13, 9, 4)`` variant, and ``forward`` ignores the hidden state.
        """
        return torch.zeros(4, 16, 9)
def getData(datapath):
    """Load every sample row from the saved NumPy arrays in a directory.

    Each regular file in ``datapath`` is read with ``np.load``. Its rows (the
    items along the first axis) are appended to one flat list. Files are
    visited in sorted name order so the result is reproducible; subdirectories
    are skipped.

    Args:
        datapath: Directory containing the array files.

    Returns:
        list: One entry per row, across all files, in file-name order.
    """
    data = []
    # os.listdir order is arbitrary and platform-dependent; sort for reproducibility.
    for session_file in sorted(os.listdir(datapath)):
        filepath = os.path.join(datapath, session_file)
        if not os.path.isfile(filepath):
            continue  # np.load cannot read a directory
        # Iterating an ndarray yields its rows, so extend() keeps each row as one sample.
        data.extend(np.load(filepath))
    print('Data Loaded')
    return data
def prediction(data, model=None):
    """Print the majority-vote label and per-class vote shares for ``data``.

    This is a thin wrapper around ``real_time_prediction``, which previously
    duplicated this function line for line.

    Args:
        data: Iterable of samples, each convertible to a float tensor.
        model: Model to run. If omitted, the module-level ``model`` global is
            used, as before (the ``__main__`` block sets it).

    Returns:
        The predicted label produced by ``real_time_prediction``.

    Raises:
        NameError: If no model is passed and no module-level ``model`` exists.
    """
    if model is None:
        model = globals().get("model")
        if model is None:
            raise NameError(
                "prediction() needs a model: pass one or set the module-level 'model'")
    return real_time_prediction(model, data)
def real_time_prediction(model, data):
    """Classify every sample in ``data`` and return the majority-vote label.

    Each sample is run through ``model`` as a batch of one. The arg-max class
    of each output is tallied, and the winning label is printed followed by
    the vote share of every class. Ties go to the lowest class index.

    Args:
        model: Callable as ``model(batch, hidden) -> (scores, hidden)``, where
            ``scores`` has one column per class (e.g. ``net``).
        data: Non-empty iterable of samples, each convertible to a float tensor.

    Returns:
        str: ``ACTION[winning_class]``.

    Raises:
        ValueError: If ``data`` yields no samples.
    """
    counts = None
    with torch.no_grad():  # inference only: don't build an autograd graph per sample
        for sample in data:
            batch = torch.as_tensor(sample, dtype=torch.float32).unsqueeze(0)
            scores, _ = model(batch, 0)
            if counts is None:
                # Size the tally from the model's output instead of hard-coding 3 classes.
                counts = [0] * scores.shape[-1]
            counts[int(scores.argmax())] += 1
    if counts is None:
        raise ValueError("real_time_prediction() got no samples to classify")
    total = sum(counts)
    # max() returns the first maximal index, so ties resolve to the lowest class.
    pred = max(range(len(counts)), key=counts.__getitem__)
    print(ACTION[pred], *(f'{count * 100 / total:.2f}%' for count in counts))
    return ACTION[pred]
def init(path="./acc100.00.pt"):
    """Build a ``net`` and load trained weights from ``path`` for inference.

    Args:
        path: Checkpoint file holding a ``state_dict`` for ``net(output_size=2)``.

    Returns:
        net: The loaded model, switched to evaluation mode.
    """
    model = net(output_size=2)
    # The model is built on the CPU, so map any CUDA tensors saved during
    # training onto the CPU; otherwise loading fails on machines without a GPU.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
    # (PyTorch >= 1.13 also accepts weights_only=True for plain state_dicts).
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    return model
if __name__ == "__main__":
    # Script entry: load the checkpoint, read both labelled test folders, then
    # report the majority-vote prediction for each folder in turn.
    # `model` stays a module-level global because prediction() looks it up there.
    model = init()
    test_sets = [getData(folder) for folder in ("test/go", "test/none")]
    for samples in test_sets:
        prediction(samples)