babi_process.py
# https://arxiv.org/abs/1503.08895 MemN2N
import numpy as np
import re
import os
from collections import deque

def data_get(data_path, data_num=1, dataset='train', memory_capacity=50):  # read the data and parse it into (story, question, answer) triples; padding and word-to-index conversion happen later in data_to_vector
    filename = list(os.walk(data_path))[0][2]
    result = []
    for files in filename:
        if (dataset in files) and (data_num == int(files.split('_')[0][2:])):  # train or test
            with open(data_path + files, 'r') as o:
                for i in o:
                    if i.split()[0] == '1':  # line number 1 marks the start of a new story
                        story = deque(maxlen=memory_capacity)  # keep at most memory_capacity (50) sentences
                    i = re.sub("[^a-zA-Z?, ]+", '', i).lower()[1:]  # keep only letters, '?', ',' and spaces; lowercase; drop the leading space left by the line number (slicing)
                    if '?' in i:  # question line
                        i = i.split("?")
                        question = i[0].split()  # ex) ['is', 'bill', 'in', 'the', 'room']
                        answer = i[1].strip()  # ex) 'no'
                        # split on ',' and sort before rejoining, because the same answer can be
                        # written in different orders, e.g. 'n,e' vs 'e,n'.
                        answer = [','.join(sorted(answer.split(',')))]  # ex) ['no']
                        sqa = [list(story.copy()), question.copy(), answer.copy()]  # Story, Question, Answer
                        result.append(sqa)
                    else:
                        i = i.split()  # ex) ['mary', 'is', 'in', 'the', 'school']
                        story.append(i)
    return result
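
# Illustrative sketch of the file layout this parser expects (the sentences below are
# made up, following the bAbI format of numbered story lines and tab-separated
# question lines with an answer and supporting-fact id):
#   1 Mary is in the school.
#   2 Bill is in the room.
#   3 Is Bill in the room? <TAB> yes <TAB> 2
# For the question line, data_get appends one entry shaped like
#   [[['mary', 'is', 'in', 'the', 'school'], ['bill', 'is', 'in', 'the', 'room']],
#    ['is', 'bill', 'in', 'the', 'room'],
#    ['yes']]
# i.e. [story sentences so far, question tokens, answer].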

def get_word_dict_and_maximum_word_in_sentence(dataset):
    word_dict = {}
    rev_word_dict = {}
    maximum_word_in_sentence = 0
    count = 0
    for data in dataset:  # (task 1 data ~ task 20 data) * 3: 20 train, 20 valid and 20 test sets
        for story, question, answer in data:
            ### story ###
            for s_sentence in story:
                maximum_word_in_sentence = max(maximum_word_in_sentence, len(s_sentence))
                for word in s_sentence:
                    if word not in word_dict:
                        word_dict[word] = count
                        rev_word_dict[count] = word
                        count += 1
            ### question ###
            maximum_word_in_sentence = max(maximum_word_in_sentence, len(question))
            for word in question:
                if word not in word_dict:
                    word_dict[word] = count
                    rev_word_dict[count] = word
                    count += 1
            ### answer ###
            word = answer[0]
            if word not in word_dict:
                word_dict[word] = count
                rev_word_dict[count] = word
                count += 1
    word_dict['pad'] = -1
    rev_word_dict[-1] = 'pad'
    return word_dict, rev_word_dict, maximum_word_in_sentence
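
# Illustrative example (made-up words): indices are assigned in first-seen order, e.g.
#   word_dict = {'mary': 0, 'is': 1, 'in': 2, 'the': 3, 'school': 4, ..., 'pad': -1}
# rev_word_dict inverts this mapping for decoding, and maximum_word_in_sentence is
# the padding length used by data_to_vector below.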

def train_vali_split(data, vali_ratio):
    train = []
    vali = []
    for task_data in data:
        vali.append(task_data[:int(len(task_data) * vali_ratio)])
        train.append(task_data[int(len(task_data) * vali_ratio):])
    return train, vali

def data_to_vector(data, word_dict, maximum_word_in_sentence, memory_capacity=50):
    result = []
    for task_data in data:
        task = []
        for story, question, answer in task_data:
            sentence_number = np.arange(1, len(story) + 1)  # sentence indices (currently unused)
            ### story ###
            s_vector = []
            for s_sentence in story:
                temp = [word_dict[word] for word in s_sentence]
                temp.extend([-1] * (maximum_word_in_sentence - len(temp)))  # pad each sentence to maximum_word_in_sentence with -1 ('pad')
                s_vector.append(temp)
            s_vector.extend([[-1] * maximum_word_in_sentence] * (memory_capacity - len(s_vector)))  # pad the story with all-pad sentences up to memory_capacity
            ### question ###
            q_vector = [word_dict[word] for word in question]
            q_vector.extend([-1] * (maximum_word_in_sentence - len(q_vector)))
            ### answer ###
            a_vector = [word_dict[answer[0]]]
            ### task ###
            task.append([s_vector, q_vector, a_vector])
        ### result ###
        result.append(task)
    return result
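
# After vectorization, each example has a fixed shape (under the defaults above):
#   s_vector: memory_capacity x maximum_word_in_sentence matrix of word indices (-1 = pad)
#   q_vector: maximum_word_in_sentence word indices
#   a_vector: a single answer index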

def merge_tasks(data):
    merge = []
    for task_data in data:
        merge.extend(task_data)
    return merge
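
# A minimal end-to-end sketch of how these helpers chain together (not part of the
# original file; the data path and filename convention, e.g.
# 'qa1_single-supporting-fact_train.txt' from the standard bAbI-tasks release, are
# assumptions here):
if __name__ == '__main__':
    data_path = './tasks_1-20_v1-2/en/'  # hypothetical location of the bAbI task files
    # parse all 20 tasks for both splits
    train_raw = [data_get(data_path, data_num=n, dataset='train') for n in range(1, 21)]
    test_raw = [data_get(data_path, data_num=n, dataset='test') for n in range(1, 21)]
    # build the vocabulary over every split so train/valid/test share word indices
    word_dict, rev_word_dict, max_words = get_word_dict_and_maximum_word_in_sentence(train_raw + test_raw)
    # hold out 10% of each task's training data for validation, then vectorize and flatten
    train_raw, vali_raw = train_vali_split(train_raw, vali_ratio=0.1)
    train = merge_tasks(data_to_vector(train_raw, word_dict, max_words))
    vali = merge_tasks(data_to_vector(vali_raw, word_dict, max_words))
    test = merge_tasks(data_to_vector(test_raw, word_dict, max_words))
    print(len(train), len(vali), len(test), max_words)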