
Commit 86d5b43

keras & tf models: refactor - move aux classes below main model class
1 parent cbc6487 · commit 86d5b43

File tree: 2 files changed, +130 −129 lines


keras_model.py: +95 −94
@@ -24,98 +24,6 @@
     ModelCheckpointSaverCallback, MultiBatchCallback, ModelTrainingProgressLoggerCallback


-class ModelEvaluationCallback(MultiBatchCallback):
-    """
-    This callback is passed to the `model.fit()` call.
-    It is responsible to trigger model evaluation during the training.
-    The reason we use a callback and not just passing validation data to `model.fit()` is because:
-        (i) the training model is different than the evaluation model for efficiency considerations;
-        (ii) we want to control the logging format;
-        (iii) we want the evaluation to occur once per 1K batches (rather than only once per epoch).
-    """
-
-    def __init__(self, code2vec_model: 'Code2VecModel'):
-        self.code2vec_model = code2vec_model
-        self.avg_eval_duration: Optional[int] = None
-        super(ModelEvaluationCallback, self).__init__(self.code2vec_model.config.NUM_TRAIN_BATCHES_TO_EVALUATE)
-
-    def on_epoch_end(self, epoch, logs=None):
-        self.perform_evaluation()
-
-    def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
-        self.perform_evaluation()
-
-    def perform_evaluation(self):
-        if self.avg_eval_duration is None:
-            self.code2vec_model.log('Evaluating...')
-        else:
-            self.code2vec_model.log('Evaluating... (takes ~{})'.format(
-                str(datetime.timedelta(seconds=int(self.avg_eval_duration)))))
-        eval_start_time = time.time()
-        evaluation_results = self.code2vec_model.evaluate()
-        eval_duration = time.time() - eval_start_time
-        if self.avg_eval_duration is None:
-            self.avg_eval_duration = eval_duration
-        else:
-            self.avg_eval_duration = eval_duration * 0.5 + self.avg_eval_duration * 0.5
-        self.code2vec_model.log('Done evaluating (took {}). Evaluation results:'.format(
-            str(datetime.timedelta(seconds=int(eval_duration)))))
-
-        self.code2vec_model.log(
-            '    loss: {loss:.4f}, f1: {f1:.4f}, recall: {recall:.4f}, precision: {precision:.4f}'.format(
-                loss=evaluation_results.loss, f1=evaluation_results.subtoken_f1,
-                recall=evaluation_results.subtoken_recall, precision=evaluation_results.subtoken_precision))
-        top_k_acc_formated = ['top{}: {:.4f}'.format(i, acc) for i, acc in enumerate(evaluation_results.topk_acc, start=1)]
-        for top_k_acc_chunk in common.chunks(top_k_acc_formated, 5):
-            self.code2vec_model.log('    ' + (', '.join(top_k_acc_chunk)))
-
-
-class _KerasModelInputTensorsFormer(ModelInputTensorsFormer):
-    """
-    An instance of this class is passed to the reader in order to help the reader to construct the input
-    in the form that the model expects to receive it.
-    This class also enables conveniently & clearly access input parts by their field names.
-        eg: 'tensors.path_indices' instead if 'tensors[1]'.
-    This allows the input tensors to be passed as pure tuples along the computation graph, while the
-    python functions that construct the graph can easily (and clearly) access tensors.
-    """
-
-    def __init__(self, estimator_action: EstimatorAction):
-        self.estimator_action = estimator_action
-
-    def to_model_input_form(self, input_tensors: ReaderInputTensors):
-        inputs = (input_tensors.path_source_token_indices, input_tensors.path_indices,
-                  input_tensors.path_target_token_indices, input_tensors.context_valid_mask)
-        if self.estimator_action.is_train:
-            targets = input_tensors.target_index
-        else:
-            targets = {'target_index': input_tensors.target_index, 'target_string': input_tensors.target_string}
-        if self.estimator_action.is_predict:
-            inputs += (input_tensors.path_source_token_strings, input_tensors.path_strings,
-                       input_tensors.path_target_token_strings)
-        return inputs, targets
-
-    def from_model_input_form(self, input_row) -> ReaderInputTensors:
-        inputs, targets = input_row
-        return ReaderInputTensors(
-            path_source_token_indices=inputs[0],
-            path_indices=inputs[1],
-            path_target_token_indices=inputs[2],
-            context_valid_mask=inputs[3],
-            target_index=targets if self.estimator_action.is_train else targets['target_index'],
-            target_string=targets['target_string'] if not self.estimator_action.is_train else None,
-            path_source_token_strings=inputs[4] if self.estimator_action.is_predict else None,
-            path_strings=inputs[5] if self.estimator_action.is_predict else None,
-            path_target_token_strings=inputs[6] if self.estimator_action.is_predict else None
-        )
-
-
-"""Used for convenient-and-clear access to raw prediction result parts (by their names)."""
-KerasPredictionModelOutput = namedtuple(
-    'KerasModelOutput', ['target_index', 'code_vectors', 'attention_weights',
-                         'topk_predicted_words', 'topk_predicted_words_scores'])
-
-
 class Code2VecModel(Code2VecModelBase):
     def __init__(self, config: Config):
         self.keras_train_model: Optional[keras.Model] = None
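
The relocated `ModelEvaluationCallback` is, per its docstring, meant to be handed to `model.fit()` so that evaluation runs every N training batches rather than only at epoch boundaries. A rough sketch of that wiring pattern (the callback class, config fields, and `run_evaluation` hook below are illustrative placeholders, not code from this repository):

# Illustrative sketch only: how a "run every N batches" Keras callback is
# typically attached to training. Dataset and hook details are placeholders.
from tensorflow import keras

class EveryNBatchesCallback(keras.callbacks.Callback):
    """Invokes a hook once every `n` batches, similar in spirit to MultiBatchCallback."""
    def __init__(self, n, hook):
        super().__init__()
        self.n = n
        self.hook = hook

    def on_batch_end(self, batch, logs=None):
        if (batch + 1) % self.n == 0:
            self.hook()

# model.fit(train_dataset,
#           epochs=num_epochs,  # placeholder
#           callbacks=[EveryNBatchesCallback(1000, run_evaluation)])
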
@@ -175,7 +83,8 @@ def _create_keras_model(self):
         # We use another dedicated Keras model for evaluation.
         # The evaluation model outputs the `topk_predicted_words` as a 2nd output.
         # The separation between train and eval models is for efficiency.
-        self.keras_eval_model = keras.Model(inputs=inputs, outputs=[target_index, topk_predicted_words])
+        self.keras_eval_model = keras.Model(
+            inputs=inputs, outputs=[target_index, topk_predicted_words], name="code2vec-keras-model")

         # We use another dedicated Keras function to produce predictions.
         # It have additional outputs than the original model.
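
The substantive change in this hunk is that the evaluation model is now built with an explicit `name`. `keras.Model` accepts a `name` keyword that labels the model in summaries and saved artifacts; a minimal, self-contained illustration (the layer sizes here are arbitrary):

# Minimal illustration of naming a Keras functional model (shapes are arbitrary).
from tensorflow import keras

inp = keras.Input(shape=(8,))
out = keras.layers.Dense(4)(inp)
model = keras.Model(inputs=inp, outputs=out, name="code2vec-keras-model")
print(model.name)  # -> "code2vec-keras-model"
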
@@ -327,7 +236,7 @@ def _save_inner_model(self, path):
     def _create_inner_model(self):
         self._create_keras_model()
         self._compile_keras_model()
-        self.keras_train_model.summary()
+        self.keras_train_model.summary(print_fn=self.log)

     def _load_inner_model(self):
         self._create_keras_model()
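
`keras.Model.summary()` writes to stdout unless given a `print_fn`; passing the model's own `log` method routes the architecture summary into the training log instead. A small sketch of the same mechanism using the standard `logging` module (the logger name is illustrative):

# Illustration: routing keras.Model.summary() output through a logger via print_fn.
import logging
from tensorflow import keras

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("code2vec")

inp = keras.Input(shape=(8,))
model = keras.Model(inputs=inp, outputs=keras.layers.Dense(4)(inp))
model.summary(print_fn=logger.info)  # each summary line is passed to logger.info
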
@@ -412,3 +321,95 @@ def _create_lookup_tables(self):

     def _initialize(self):
         self._create_lookup_tables()
+
+
+class ModelEvaluationCallback(MultiBatchCallback):
+    """
+    This callback is passed to the `model.fit()` call.
+    It is responsible to trigger model evaluation during the training.
+    The reason we use a callback and not just passing validation data to `model.fit()` is because:
+        (i) the training model is different than the evaluation model for efficiency considerations;
+        (ii) we want to control the logging format;
+        (iii) we want the evaluation to occur once per 1K batches (rather than only once per epoch).
+    """
+
+    def __init__(self, code2vec_model: 'Code2VecModel'):
+        self.code2vec_model = code2vec_model
+        self.avg_eval_duration: Optional[int] = None
+        super(ModelEvaluationCallback, self).__init__(self.code2vec_model.config.NUM_TRAIN_BATCHES_TO_EVALUATE)
+
+    def on_epoch_end(self, epoch, logs=None):
+        self.perform_evaluation()
+
+    def on_multi_batch_end(self, batch, logs, multi_batch_elapsed):
+        self.perform_evaluation()
+
+    def perform_evaluation(self):
+        if self.avg_eval_duration is None:
+            self.code2vec_model.log('Evaluating...')
+        else:
+            self.code2vec_model.log('Evaluating... (takes ~{})'.format(
+                str(datetime.timedelta(seconds=int(self.avg_eval_duration)))))
+        eval_start_time = time.time()
+        evaluation_results = self.code2vec_model.evaluate()
+        eval_duration = time.time() - eval_start_time
+        if self.avg_eval_duration is None:
+            self.avg_eval_duration = eval_duration
+        else:
+            self.avg_eval_duration = eval_duration * 0.5 + self.avg_eval_duration * 0.5
+        self.code2vec_model.log('Done evaluating (took {}). Evaluation results:'.format(
+            str(datetime.timedelta(seconds=int(eval_duration)))))
+
+        self.code2vec_model.log(
+            '    loss: {loss:.4f}, f1: {f1:.4f}, recall: {recall:.4f}, precision: {precision:.4f}'.format(
+                loss=evaluation_results.loss, f1=evaluation_results.subtoken_f1,
+                recall=evaluation_results.subtoken_recall, precision=evaluation_results.subtoken_precision))
+        top_k_acc_formated = ['top{}: {:.4f}'.format(i, acc) for i, acc in enumerate(evaluation_results.topk_acc, start=1)]
+        for top_k_acc_chunk in common.chunks(top_k_acc_formated, 5):
+            self.code2vec_model.log('    ' + (', '.join(top_k_acc_chunk)))
+
+
+class _KerasModelInputTensorsFormer(ModelInputTensorsFormer):
+    """
+    An instance of this class is passed to the reader in order to help the reader to construct the input
+    in the form that the model expects to receive it.
+    This class also enables conveniently & clearly access input parts by their field names.
+        eg: 'tensors.path_indices' instead if 'tensors[1]'.
+    This allows the input tensors to be passed as pure tuples along the computation graph, while the
+    python functions that construct the graph can easily (and clearly) access tensors.
+    """
+
+    def __init__(self, estimator_action: EstimatorAction):
+        self.estimator_action = estimator_action
+
+    def to_model_input_form(self, input_tensors: ReaderInputTensors):
+        inputs = (input_tensors.path_source_token_indices, input_tensors.path_indices,
+                  input_tensors.path_target_token_indices, input_tensors.context_valid_mask)
+        if self.estimator_action.is_train:
+            targets = input_tensors.target_index
+        else:
+            targets = {'target_index': input_tensors.target_index, 'target_string': input_tensors.target_string}
+        if self.estimator_action.is_predict:
+            inputs += (input_tensors.path_source_token_strings, input_tensors.path_strings,
+                       input_tensors.path_target_token_strings)
+        return inputs, targets
+
+    def from_model_input_form(self, input_row) -> ReaderInputTensors:
+        inputs, targets = input_row
+        return ReaderInputTensors(
+            path_source_token_indices=inputs[0],
+            path_indices=inputs[1],
+            path_target_token_indices=inputs[2],
+            context_valid_mask=inputs[3],
+            target_index=targets if self.estimator_action.is_train else targets['target_index'],
+            target_string=targets['target_string'] if not self.estimator_action.is_train else None,
+            path_source_token_strings=inputs[4] if self.estimator_action.is_predict else None,
+            path_strings=inputs[5] if self.estimator_action.is_predict else None,
+            path_target_token_strings=inputs[6] if self.estimator_action.is_predict else None
+        )
+
+
+"""Used for convenient-and-clear access to raw prediction result parts (by their names)."""
+KerasPredictionModelOutput = namedtuple(
+    'KerasModelOutput', ['target_index', 'code_vectors', 'attention_weights',
+                         'topk_predicted_words', 'topk_predicted_words_scores'])
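
The relocated `KerasPredictionModelOutput` namedtuple exists purely so prediction outputs can be read by field name instead of tuple position, as its preceding comment says. A toy illustration with dummy values (none of these numbers come from a real model):

# Illustration of the named access KerasPredictionModelOutput provides.
# All field values below are dummy placeholders, not real model outputs.
from collections import namedtuple

KerasPredictionModelOutput = namedtuple(
    'KerasModelOutput', ['target_index', 'code_vectors', 'attention_weights',
                         'topk_predicted_words', 'topk_predicted_words_scores'])

raw_output = KerasPredictionModelOutput(
    target_index=0, code_vectors=[[0.1, 0.2]], attention_weights=[[0.7, 0.3]],
    topk_predicted_words=['get|name'], topk_predicted_words_scores=[0.9])

print(raw_output.topk_predicted_words)  # named access instead of raw_output[3]
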

tensorflow_model.py: +35 −35
@@ -15,41 +15,6 @@
 tf.compat.v1.disable_eager_execution()


-class _TFTrainModelInputTensorsFormer(ModelInputTensorsFormer):
-    def to_model_input_form(self, input_tensors: ReaderInputTensors):
-        return input_tensors.target_index, input_tensors.path_source_token_indices, input_tensors.path_indices, \
-               input_tensors.path_target_token_indices, input_tensors.context_valid_mask
-
-    def from_model_input_form(self, input_row) -> ReaderInputTensors:
-        return ReaderInputTensors(
-            target_index=input_row[0],
-            path_source_token_indices=input_row[1],
-            path_indices=input_row[2],
-            path_target_token_indices=input_row[3],
-            context_valid_mask=input_row[4]
-        )
-
-
-class _TFEvaluateModelInputTensorsFormer(ModelInputTensorsFormer):
-    def to_model_input_form(self, input_tensors: ReaderInputTensors):
-        return input_tensors.target_string, input_tensors.path_source_token_indices, input_tensors.path_indices, \
-               input_tensors.path_target_token_indices, input_tensors.context_valid_mask, \
-               input_tensors.path_source_token_strings, input_tensors.path_strings, \
-               input_tensors.path_target_token_strings
-
-    def from_model_input_form(self, input_row) -> ReaderInputTensors:
-        return ReaderInputTensors(
-            target_string=input_row[0],
-            path_source_token_indices=input_row[1],
-            path_indices=input_row[2],
-            path_target_token_indices=input_row[3],
-            context_valid_mask=input_row[4],
-            path_source_token_strings=input_row[5],
-            path_strings=input_row[6],
-            path_target_token_strings=input_row[7]
-        )
-
-
 class Code2VecModel(Code2VecModelBase):
     def __init__(self, config: Config):
         self.sess = tf.compat.v1.Session()
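
Both relocated formers implement the same contract: `to_model_input_form` flattens the named `ReaderInputTensors` fields into a positional tuple that can flow through the TF input pipeline, and `from_model_input_form` rebuilds the named view on the other side. A stripped-down sketch of that round trip (`MiniTensors` is a hypothetical two-field stand-in, not the real `ReaderInputTensors`):

# Simplified stand-in for the pack/unpack contract of a ModelInputTensorsFormer.
from typing import NamedTuple, Optional

class MiniTensors(NamedTuple):
    target_index: Optional[int] = None
    path_indices: Optional[int] = None

def to_model_input_form(t: MiniTensors):
    return t.target_index, t.path_indices  # named fields -> positional tuple

def from_model_input_form(row) -> MiniTensors:
    return MiniTensors(target_index=row[0], path_indices=row[1])  # and back

row = to_model_input_form(MiniTensors(target_index=3, path_indices=7))
assert from_model_input_form(row) == MiniTensors(3, 7)
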
@@ -528,3 +493,38 @@ def update_batch(self, results):
     @property
     def topk_correct_predictions(self):
         return self.nr_correct_predictions / self.nr_predictions
+
+
+class _TFTrainModelInputTensorsFormer(ModelInputTensorsFormer):
+    def to_model_input_form(self, input_tensors: ReaderInputTensors):
+        return input_tensors.target_index, input_tensors.path_source_token_indices, input_tensors.path_indices, \
+               input_tensors.path_target_token_indices, input_tensors.context_valid_mask
+
+    def from_model_input_form(self, input_row) -> ReaderInputTensors:
+        return ReaderInputTensors(
+            target_index=input_row[0],
+            path_source_token_indices=input_row[1],
+            path_indices=input_row[2],
+            path_target_token_indices=input_row[3],
+            context_valid_mask=input_row[4]
+        )
+
+
+class _TFEvaluateModelInputTensorsFormer(ModelInputTensorsFormer):
+    def to_model_input_form(self, input_tensors: ReaderInputTensors):
+        return input_tensors.target_string, input_tensors.path_source_token_indices, input_tensors.path_indices, \
+               input_tensors.path_target_token_indices, input_tensors.context_valid_mask, \
+               input_tensors.path_source_token_strings, input_tensors.path_strings, \
+               input_tensors.path_target_token_strings
+
+    def from_model_input_form(self, input_row) -> ReaderInputTensors:
+        return ReaderInputTensors(
+            target_string=input_row[0],
+            path_source_token_indices=input_row[1],
+            path_indices=input_row[2],
+            path_target_token_indices=input_row[3],
+            context_valid_mask=input_row[4],
+            path_source_token_strings=input_row[5],
+            path_strings=input_row[6],
+            path_target_token_strings=input_row[7]
+        )
