-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodels.py
82 lines (62 loc) · 2.78 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from tensorflow.keras import layers, models
def get_plain_model(input_shape, output_shape, config):
    """Build a plain fully-connected (dense) Keras model.

    The input is averaged over its first axis, flattened, passed through the
    configured stack of Dense layers (each optionally followed by Dropout),
    and finished with a linear Dense output layer.

    Args:
        input_shape: Tuple shape of one input sample; input_shape[0] is the
            axis averaged over (presumably the spin/value axis — confirm
            against the data pipeline).
        output_shape: Unit count of the final Dense layer.
        config: Dict with 'dropout' (falsy disables Dropout layers) and
            'layers' -> 'dense_layers' (list of kwargs dicts for Dense).

    Returns:
        An uncompiled tf.keras Sequential model.
    """
    dropout_rate = config['dropout']
    net = models.Sequential([layers.InputLayer(input_shape=input_shape)])
    # Collapse the first axis by averaging, then flatten to a feature vector.
    net.add(layers.AveragePooling1D(pool_size=input_shape[0], strides=1))
    net.add(layers.Flatten())
    for dense_kwargs in config['layers']['dense_layers']:
        net.add(layers.Dense(**dense_kwargs))
        if dropout_rate:
            net.add(layers.Dropout(dropout_rate))
    net.add(layers.Dense(output_shape))
    return net
def get_conv_model(input_shape, output_shape, config):
    """Build a 2-D convolutional Keras model.

    Optionally averages over the first input axis, adds a trailing channel
    dimension, applies the configured Conv2D stack, flattens, then applies
    the configured Dense stack before a linear Dense output layer. Dropout
    (when enabled) follows every Conv2D and Dense layer.

    Args:
        input_shape: Tuple shape of one input sample, e.g. (vals, qudits).
        output_shape: Unit count of the final Dense layer.
        config: Dict with 'dropout' (falsy disables Dropout layers) and
            'layers' containing 'spin_average' (bool), 'conv_layers' and
            'dense_layers' (lists of kwargs dicts).

    Returns:
        An uncompiled tf.keras Sequential model.
    """
    dropout_rate = config['dropout']
    net = models.Sequential([layers.InputLayer(input_shape=input_shape)])
    if config['layers']['spin_average']:
        net.add(layers.AveragePooling1D(pool_size=input_shape[0], strides=1))
    # Append a channels axis so Conv2D sees (batch, vals, qudits, 1).
    net.add(layers.Reshape(input_shape + (1,)))
    for conv_kwargs in config['layers']['conv_layers']:
        net.add(layers.Conv2D(**conv_kwargs))
        if dropout_rate:
            net.add(layers.Dropout(dropout_rate))
    net.add(layers.Flatten())
    for dense_kwargs in config['layers']['dense_layers']:
        net.add(layers.Dense(**dense_kwargs))
        if dropout_rate:
            net.add(layers.Dropout(dropout_rate))
    net.add(layers.Dense(output_shape))
    return net
def get_recurrent_model(input_shape, output_shape, config):
    """Build a bidirectional recurrent Keras model.

    Optionally averages over the first input axis, permutes to put the
    sequence axis first, runs the configured bidirectional recurrent stack,
    then the configured Dense stack before a linear Dense output layer.
    Dropout (when enabled) follows every recurrent and Dense layer.

    Args:
        input_shape: Tuple shape of one input sample, e.g. (vals, qudits).
        output_shape: Unit count of the final Dense layer.
        config: Dict with 'dropout' (falsy disables Dropout layers),
            'recurrent_type' ('LSTM', 'GRU', or anything else for
            SimpleRNN), and 'layers' containing 'spin_average' (bool),
            'recurrent_layers' and 'dense_layers' (lists of kwargs dicts).

    Returns:
        An uncompiled tf.keras Sequential model.
    """
    dropout_rate = config['dropout']
    # Anything other than 'LSTM'/'GRU' falls back to SimpleRNN.
    rnn_cls = {'LSTM': layers.LSTM, 'GRU': layers.GRU}.get(
        config['recurrent_type'], layers.SimpleRNN)
    net = models.Sequential([layers.InputLayer(input_shape=input_shape)])
    if config['layers']['spin_average']:
        net.add(layers.AveragePooling1D(pool_size=input_shape[0], strides=1))
    # Swap axes so the recurrence runs over (batch, qudits, vals).
    net.add(layers.Permute((2, 1)))
    for rnn_kwargs in config['layers']['recurrent_layers']:
        net.add(layers.Bidirectional(rnn_cls(**rnn_kwargs)))
        if dropout_rate:
            net.add(layers.Dropout(dropout_rate))
    for dense_kwargs in config['layers']['dense_layers']:
        net.add(layers.Dense(**dense_kwargs))
        if dropout_rate:
            net.add(layers.Dropout(dropout_rate))
    net.add(layers.Dense(output_shape))
    return net
def get_model(input_shape, output_shape, config):
    """Build the model selected by config['model_type'].

    Args:
        input_shape: Tuple shape of one input sample, forwarded to the
            selected builder.
        output_shape: Unit count of the final Dense layer, forwarded.
        config: Dict whose 'model_type' key ('plain', 'conv', or
            'recurrent') selects the builder; the rest of the dict is
            consumed by that builder.

    Returns:
        An uncompiled tf.keras Sequential model.

    Raises:
        ValueError: If 'model_type' is not one of the supported names.
    """
    builders = {
        'plain': get_plain_model,
        'conv': get_conv_model,
        'recurrent': get_recurrent_model,
    }
    builder = builders.get(config['model_type'])
    if builder is None:
        raise ValueError('Incorrect model type specified.')
    return builder(input_shape, output_shape, config)