-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path predict.py
40 lines (33 loc) · 1.07 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import h5py
import yaoai
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.layers import Masking
def main(sentence_length=40, input_path='./input.txt',
         weights_path='./model.hdf5', output_path='./output.txt'):
    """Run the saved LSTM over every input window and write decoded predictions.

    Rebuilds the two-layer LSTM architecture, loads its weights, predicts a
    softmax distribution for every timestep of every window produced by
    ``yaoai.setup``, and writes the arg-max token of each timestep to
    ``output_path``, with a blank line after each window.

    Args:
        sentence_length: Window length passed to ``yaoai.setup``.
        input_path: Text file passed to ``yaoai.setup``.
        weights_path: Keras weights file to load; it must match the layer
            stack built below.
        output_path: Destination text file; overwritten if it exists.
    """
    # NOTE(review): the meaning of the five values comes from yaoai.setup,
    # which is not visible here; only dataX, dataY and int_to_word are used.
    dataX, dataY, _story, _word_to_int, int_to_word = yaoai.setup(
        sentence_length, input_path)

    model = Sequential([
        # Masking skips timesteps whose features all equal 0.0.
        # NOTE(review): assumes yaoai.setup pads with zeros — confirm.
        Masking(mask_value=0.0, input_shape=dataX.shape[1:]),
        LSTM(256, return_sequences=True),
        Dropout(0.2),
        LSTM(256, return_sequences=True),
        Dropout(0.2),
        # One softmax over dataY's last axis for every timestep.
        Dense(dataY.shape[2], activation='softmax'),
    ])
    model.load_weights(weights_path)
    # predict() does not require a compiled model; this call is optional.
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    model_output = model.predict(dataX)
    # Most probable class id at every timestep, computed in one vectorized
    # call. Ties resolve to the first index, the same as a per-row argmax.
    predicted_ids = np.argmax(model_output, axis=-1)

    # The with-block closes the file even if a lookup or write fails, and an
    # explicit encoding avoids depending on the platform's locale default.
    with open(output_path, 'w', encoding='utf-8') as out_file:
        for window_ids in predicted_ids:
            # NOTE(review): tokens are concatenated with no separator; if
            # int_to_word maps to whole words, a ' ' join may be intended.
            out_file.write(''.join(int_to_word[int(i)] for i in window_ids))
            out_file.write('\n\n')
# Script entry point: run the prediction pipeline only when this file is
# executed directly, never when it is imported as a module.
if __name__ == '__main__':
    main()