Skip to content

Commit 6f55f36

Browse files
authoredDec 4, 2017
Add files via upload
1 parent e57250c commit 6f55f36

File tree

2 files changed

+101
-3
lines changed

2 files changed

+101
-3
lines changed
 

‎picross_dataset.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from picross_generator import generator as gen
66

77
def load_data(row, col, visible, num):
8-
x = np.zeros((num, row, col), dtype='int16')
9-
y = np.zeros((num, 2, row, col), dtype='int16')
8+
x = np.zeros((num, 2, row, col), dtype='int16')
9+
y = np.zeros((num, row, col), dtype='int16')
1010

1111
for i in range(num):
12-
x[i], y[i, 0], y[i, 1] = gen(row, col, visible)
12+
y[i], x[i, 0], x[i, 1] = gen(row, col, visible)
1313

1414
return x, y
1515

‎picross_rnn.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#picross_rnn.py
2+
#copyright: fiorezhang@sina.com
3+
4+
import numpy as np
5+
from keras.models import Sequential
6+
from keras import layers
7+
from picross_dataset import load_data as load
8+
9+
#global variables for dataset
10+
SIZE_ROW = 3
11+
SIZE_COL = 5
12+
VISIBLE = 0.6
13+
TRAIN_SIZE = 5000
14+
TEST_SIZE = 500
15+
16+
#generate data from picross generator/dataset functions
17+
print('-'*50)
18+
print('Generating data...')
19+
print('Picross size: ', SIZE_ROW, 'x', SIZE_COL, ', visible: ', VISIBLE*100, '%')
20+
print('Train samples: ', TRAIN_SIZE, ', test samples: ', TEST_SIZE)
21+
x, y = load(SIZE_ROW, SIZE_COL, VISIBLE, TRAIN_SIZE+TEST_SIZE)
22+
x_train, x_test = x[:TRAIN_SIZE], x[TRAIN_SIZE:]
23+
y_train, y_test = y[:TRAIN_SIZE], y[TRAIN_SIZE:]
24+
print(x.shape)
25+
print(y.shape)
26+
print(x_train.shape)
27+
print(y_train.shape)
28+
print(x[0])
29+
print(y[0])
30+
31+
#set parameters for RNN modle
32+
RNN=layers.SimpleRNN
33+
HIDDEN_SIZE = 128
34+
BATCH_SIZE = 128
35+
LAYERS = 1
36+
ITERATION = 100
37+
EPOCHS = 10
38+
39+
#build the model
40+
print('-'*50)
41+
print('Building model...')
42+
model = Sequential()
43+
#reshape input, connect 2 matrix and flatten
44+
model.add(layers.Reshape((2*SIZE_ROW*SIZE_COL, 1), input_shape=(2, SIZE_ROW, SIZE_COL)))
45+
#print(model.output_shape)
46+
#import data to a RNN
47+
model.add(RNN(HIDDEN_SIZE, dropout=0.1))
48+
#print(model.output_shape)
49+
#repeat r*c times for output
50+
model.add(layers.RepeatVector(SIZE_ROW*SIZE_COL))
51+
#print(model.output_shape)
52+
#rnn again, set return_sequences to output for all time pieces
53+
for _ in range(LAYERS):
54+
model.add(RNN(HIDDEN_SIZE, return_sequences=True, dropout=0.1))
55+
#print(model.output_shape)
56+
#flatten, 3D->2D
57+
model.add(layers.TimeDistributed(layers.Dense(1)))
58+
#print(model.output_shape)
59+
#finally binary cross entropy
60+
model.add(layers.Activation('sigmoid'))
61+
#print(model.output_shape)
62+
#model.add(layers.Reshape((SIZE_ROW, SIZE_COL)))
63+
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
64+
#print(model.output_shape)
65+
model.summary()
66+
67+
#train the model and print information during the process
68+
print('-'*50)
69+
print('Training...')
70+
for iteration in range(1, ITERATION):
71+
print()
72+
print('-'*50)
73+
print('Iteration', iteration)
74+
model.fit(x_train, y_train.reshape(-1, SIZE_ROW*SIZE_COL, 1),
75+
batch_size=BATCH_SIZE,
76+
epochs=EPOCHS,
77+
validation_data=(x_test, y_test.reshape(-1, SIZE_ROW*SIZE_COL, 1)))
78+
#show result in the middle
79+
for i in range(1):
80+
ind = np.random.randint(0, len(x_test))
81+
rowx, rowy = x_test[np.array([ind])], y_test[np.array([ind])]
82+
preds = model.predict_classes(rowx, verbose=0)
83+
question = rowx[0]
84+
#print(question.shape)
85+
correct = rowy[0]
86+
#print(correct.shape)
87+
#print(preds.shape)
88+
guess = preds[0].reshape(SIZE_ROW, SIZE_COL)
89+
print('Q','- '*25)
90+
print(question)
91+
print('A','- '*25)
92+
print(correct)
93+
if (correct == guess).all():
94+
print('Y','- '*25)
95+
else:
96+
print('N','- '*25)
97+
print(guess)
98+
print('- '*25)

0 commit comments

Comments
 (0)
Please sign in to comment.