-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_results.py
35 lines (30 loc) · 1.08 KB
/
plot_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from matplotlib import pyplot as plt
from tensorflow.keras import utils
import numpy as np
from os.path import join
def plot_history(history, log_scale, model_dir):
    """Save one curve figure per metric found in ``history.history``.

    Every key of ``history.history`` that does not start with ``'val_'`` gets
    its own figure containing the training series and, when a matching
    ``'val_' + key`` entry exists, the validation series. Each figure is
    written to ``<model_dir>/<key>.png``.

    Args:
        history: Object exposing a ``history`` mapping from metric name to a
            sequence of values (presumably a Keras ``History`` -- confirm
            against the caller).
        log_scale: If truthy, curves are drawn with ``semilogy``; otherwise
            with ``plot``.
        model_dir: Directory the PNG files are written into.
    """
    h = history.history
    for key in h:
        if key.startswith('val_'):
            continue

        fig, ax = plt.subplots()
        try:
            draw = ax.semilogy if log_scale else ax.plot
            draw(h[key], label='Training')

            # Not every metric has a validation counterpart (no validation
            # data, or an entry with no 'val_' twin); plot the training
            # curve alone instead of raising KeyError.
            val_key = 'val_' + key
            if val_key in h:
                draw(h[val_key], label='Validation')

            ax.set_ylabel(key)
            ax.set_xlabel('Epoch')
            ax.legend()
            fig.savefig(join(model_dir, key + '.png'))
        finally:
            # pyplot keeps a reference to every figure it creates; close it
            # so figures do not pile up across metrics and repeated calls.
            plt.close(fig)
def plot_results(model, history, model_dir, config):
    """Write the training-curve plots and a model architecture diagram.

    Calls ``plot_history`` to save one figure per metric into ``model_dir``,
    then calls ``utils.plot_model`` to render ``model`` to
    ``<model_dir>/model_architecture.png`` with shapes, layer names and layer
    activations shown.

    Args:
        model: Model object passed through to ``utils.plot_model``.
        history: Training history object passed through to ``plot_history``.
        model_dir: Directory all output images are written into.
        config: Mapping of settings; only ``'log_scale'`` is read here. A
            missing or falsy value selects a linear y-axis.
            NOTE(review): assumes ``config`` is dict-like and provides
            ``.get`` -- confirm against the caller.
    """
    # A single lookup with a default: tolerates an absent 'log_scale' key
    # instead of raising KeyError, and normalises the value to a bool.
    log_scale = bool(config.get('log_scale', False))
    plot_history(history, log_scale, model_dir)

    utils.plot_model(
        model,
        to_file=join(model_dir, 'model_architecture.png'),
        show_shapes=True,
        show_layer_names=True,
        show_layer_activations=True,
    )