1
+ #picross_rnn.py
2
+ #copyright: fiorezhang@sina.com
3
+
4
+ import numpy as np
5
+ from keras .models import Sequential
6
+ from keras import layers
7
+ from picross_dataset import load_data as load
8
+
9
+ #global variables for dataset
10
+ SIZE_ROW = 3
11
+ SIZE_COL = 5
12
+ VISIBLE = 0.6
13
+ TRAIN_SIZE = 5000
14
+ TEST_SIZE = 500
15
+
16
+ #generate data from picross generator/dataset functions
17
+ print ('-' * 50 )
18
+ print ('Generating data...' )
19
+ print ('Picross size: ' , SIZE_ROW , 'x' , SIZE_COL , ', visible: ' , VISIBLE * 100 , '%' )
20
+ print ('Train samples: ' , TRAIN_SIZE , ', test samples: ' , TEST_SIZE )
21
+ x , y = load (SIZE_ROW , SIZE_COL , VISIBLE , TRAIN_SIZE + TEST_SIZE )
22
+ x_train , x_test = x [:TRAIN_SIZE ], x [TRAIN_SIZE :]
23
+ y_train , y_test = y [:TRAIN_SIZE ], y [TRAIN_SIZE :]
24
+ print (x .shape )
25
+ print (y .shape )
26
+ print (x_train .shape )
27
+ print (y_train .shape )
28
+ print (x [0 ])
29
+ print (y [0 ])
30
+
31
+ #set parameters for RNN modle
32
+ RNN = layers .SimpleRNN
33
+ HIDDEN_SIZE = 128
34
+ BATCH_SIZE = 128
35
+ LAYERS = 1
36
+ ITERATION = 100
37
+ EPOCHS = 10
38
+
39
+ #build the model
40
+ print ('-' * 50 )
41
+ print ('Building model...' )
42
+ model = Sequential ()
43
+ #reshape input, connect 2 matrix and flatten
44
+ model .add (layers .Reshape ((2 * SIZE_ROW * SIZE_COL , 1 ), input_shape = (2 , SIZE_ROW , SIZE_COL )))
45
+ #print(model.output_shape)
46
+ #import data to a RNN
47
+ model .add (RNN (HIDDEN_SIZE , dropout = 0.1 ))
48
+ #print(model.output_shape)
49
+ #repeat r*c times for output
50
+ model .add (layers .RepeatVector (SIZE_ROW * SIZE_COL ))
51
+ #print(model.output_shape)
52
+ #rnn again, set return_sequences to output for all time pieces
53
+ for _ in range (LAYERS ):
54
+ model .add (RNN (HIDDEN_SIZE , return_sequences = True , dropout = 0.1 ))
55
+ #print(model.output_shape)
56
+ #flatten, 3D->2D
57
+ model .add (layers .TimeDistributed (layers .Dense (1 )))
58
+ #print(model.output_shape)
59
+ #finally binary cross entropy
60
+ model .add (layers .Activation ('sigmoid' ))
61
+ #print(model.output_shape)
62
+ #model.add(layers.Reshape((SIZE_ROW, SIZE_COL)))
63
+ model .compile (loss = 'binary_crossentropy' , optimizer = 'rmsprop' , metrics = ['accuracy' ])
64
+ #print(model.output_shape)
65
+ model .summary ()
66
+
67
+ #train the model and print information during the process
68
+ print ('-' * 50 )
69
+ print ('Training...' )
70
+ for iteration in range (1 , ITERATION ):
71
+ print ()
72
+ print ('-' * 50 )
73
+ print ('Iteration' , iteration )
74
+ model .fit (x_train , y_train .reshape (- 1 , SIZE_ROW * SIZE_COL , 1 ),
75
+ batch_size = BATCH_SIZE ,
76
+ epochs = EPOCHS ,
77
+ validation_data = (x_test , y_test .reshape (- 1 , SIZE_ROW * SIZE_COL , 1 )))
78
+ #show result in the middle
79
+ for i in range (1 ):
80
+ ind = np .random .randint (0 , len (x_test ))
81
+ rowx , rowy = x_test [np .array ([ind ])], y_test [np .array ([ind ])]
82
+ preds = model .predict_classes (rowx , verbose = 0 )
83
+ question = rowx [0 ]
84
+ #print(question.shape)
85
+ correct = rowy [0 ]
86
+ #print(correct.shape)
87
+ #print(preds.shape)
88
+ guess = preds [0 ].reshape (SIZE_ROW , SIZE_COL )
89
+ print ('Q' ,'- ' * 25 )
90
+ print (question )
91
+ print ('A' ,'- ' * 25 )
92
+ print (correct )
93
+ if (correct == guess ).all ():
94
+ print ('Y' ,'- ' * 25 )
95
+ else :
96
+ print ('N' ,'- ' * 25 )
97
+ print (guess )
98
+ print ('- ' * 25 )
0 commit comments