-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
94 lines (89 loc) · 4.08 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import random
import pickle
import json
from model_dy import *
from bio_utils import *
from nltk.translate.bleu_score import corpus_bleu
if __name__ == '__main__':
    # Evaluation script: load a trained LSTMLM, encode the dev/test set with the
    # training vocabularies, predict a rule for each datapoint, and append the
    # generated rules (one per line) to the file "rules-1".
    random.seed(1)
    parser = argparse.ArgumentParser()
    parser.add_argument('model')
    parser.add_argument('datadir')
    parser.add_argument('dev_datadir')
    args = parser.parse_args()

    # Build vocabularies from the training data; parse_json_data additionally
    # yields the rule vocabulary used to decode predictions below.
    input_lang, pl1, char, raw_train = prepare_data(args.datadir)
    input_lang, pl1, char, rule_lang, raw_train = parse_json_data(input_lang, pl1, char, raw_train)
    input2_lang, pl2, char2, raw_test = prepare_data(args.dev_datadir, "valids.json")
    model = LSTMLM.load(args.model)

    def _encode(datapoint, label_id):
        """Map one raw datapoint onto the index tuple the model consumes.

        Unknown words / POS tags / characters fall back to index 2 (UNK);
        the trailing 1 and 0 are the EOS sentinels of the word and POS
        vocabularies respectively.
        """
        return ([input_lang.word2index.get(w, 2) for w in datapoint[0]] + [1],
                datapoint[1],
                datapoint[2],
                datapoint[3],
                label_id,
                [pl1.word2index.get(p, 2) for p in datapoint[-2]] + [0],
                [[char.word2index.get(c, 2) for c in w] for w in datapoint[0] + ["EOS"]],
                datapoint[-1])

    test = []
    i = j = 0  # i: labeled datapoints, j: unlabeled (label id forced to 0)
    for datapoint in raw_test:
        if datapoint[4]:
            i += 1
            try:
                test.append(_encode(datapoint, input_lang.label2id[datapoint[4]]))
            except KeyError:
                # Label missing from the training label vocabulary: skip the
                # datapoint but report it (was a silent bare `except:`).
                print(datapoint)
        else:
            j += 1
            test.append(_encode(datapoint, 0))
    print(i, j)

    # Predict a rule for every encoded datapoint and dump the decoded tokens.
    # Open the output file once instead of re-opening it per datapoint.
    with open("rules-1", "a") as f:
        for datapoint in test:
            sentence = datapoint[0]
            entity = datapoint[2]
            pos = datapoint[-3]
            chars = datapoint[-2]
            rule = model.get_pred(sentence, pos, chars, entity)
            f.write(' '.join([rule_lang.index2word[p] for p in rule]) + "\n")