predict.py
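"""Score sequences from a FASTA file with a trained CNN51_RNN one-hot model and
write each sequence header with its predicted positive-class probability
(tab-separated) to the output file."""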
import argparse

import numpy as np
import torch
import torch.nn.functional as F

from seq_load_one_hot import *
from model_one_hot import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

wordvec_len = 4   # width of the one-hot encoding per sequence position
FC_DROPOUT = 0.5
def predict(model, x):
    model.eval()  # evaluation mode: disables dropout
    fx = model.forward(x)
    return fx
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-predict_fa", "--predict_fasta", action="store", dest='predict_fa', required=True,
                        help="FASTA file of sequences to predict")
    parser.add_argument("-model_path", "--model_path", action="store", dest='model_path', required=True,
                        help="directory containing the trained model checkpoint")
    parser.add_argument("-outfile", "--outfile", action="store", dest='outfile', required=True,
                        help="output file name")
    args = parser.parse_args()

    predict_file = args.predict_fa
    model_path = args.model_path
    if model_path[-1] == '/':
        model_path = model_path[:-1]

    # Rebuild the network and load the trained weights onto the CPU.
    checkpoint = torch.load(model_path + '/' + 'checkpoint_ATCG.pth.tar', map_location=torch.device('cpu'))
    model = CNN51_RNN(FC_DROPOUT)
    model.load_state_dict(checkpoint['state_dict'])
    # Load one-hot encoded sequences and their FASTA headers, then reshape to
    # (num_sequences, sequence_length, wordvec_len).
    X_test, fa_header = load_data_bicoding_with_header(predict_file)
    X_test = np.array(X_test)
    X_test = X_test.reshape(X_test.shape[0], int(X_test.shape[1] / wordvec_len), wordvec_len)
    X_test = torch.from_numpy(X_test).float()

    batch_size = 256
    i = 0
    N = X_test.shape[0]

    with open(args.outfile, 'w') as fw:
        # Score full batches.
        while i + batch_size < N:
            x_batch = X_test[i:i + batch_size]
            header_batch = fa_header[i:i + batch_size]
            fx = predict(model, x_batch)
            prob_data = F.log_softmax(fx, dim=1).cpu().data.numpy()
            # exp(log_softmax) recovers class probabilities; column 1 is written as the score.
            for m in range(len(prob_data)):
                fw.write(header_batch[m] + '\t' + str(np.exp(prob_data)[m][1]) + '\n')
            i += batch_size

        # Score the remaining sequences in a final, smaller batch.
        x_batch = X_test[i:N]
        header_batch = fa_header[i:N]
        fx = predict(model, x_batch)
        prob_data = F.log_softmax(fx, dim=1).cpu().data.numpy()
        for m in range(len(prob_data)):
            fw.write(header_batch[m] + '\t' + str(np.exp(prob_data)[m][1]) + '\n')
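
# Example invocation (file names below are illustrative; the model directory is
# expected to contain checkpoint_ATCG.pth.tar, as loaded above):
#   python predict.py -predict_fa sequences.fa -model_path ./saved_model -outfile predictions.tsv
# Each output line is a FASTA header and the predicted positive-class probability,
# separated by a tab.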