Skip to content

Commit b350572

Browse files
Fix retinanet training
1 parent 7c8489b commit b350572

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

examples/tensorflow/common/object_detection/base_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self, params):
4949
# One can use 'RESNET_FROZEN_VAR_PREFIX' to speed up ResNet training when loading from the checkpoint
5050
# RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
5151
self._frozen_variable_prefix = ""
52-
params_train_regularization_variable_regex = r'.*(kernel|weight):0$'
52+
params_train_regularization_variable_regex = r'.*(kernel|weight|kernel_mirrored|weight_mirrored):0$'
5353
self._regularization_var_regex = params_train_regularization_variable_regex
5454
self._l2_weight_decay = params.weight_decay
5555

examples/tensorflow/object_detection/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def model_eval_fn(model):
329329
args = [model]
330330
inputs = tf.keras.layers.Input(shape=model.inputs[0].shape[1:], name=model.inputs[0].name.split(':')[0])
331331
outputs = NNCFWrapperCustom(*args, caliblration_dataset=train_dataset,
332-
enable_mirrored_vars_split=False)(inputs)
332+
enable_mirrored_vars_split=True)(inputs)
333333
compress_model = tf.keras.Model(inputs=inputs, outputs=outputs)
334334

335335
scheduler = build_scheduler(

examples/tensorflow/object_detection/models/retinanet_model.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,14 @@ def build_outputs(self, inputs, is_training):
6868

6969
return model_outputs
7070

71+
@staticmethod
72+
def get_zero_replica_from_mirrored_var(var):
73+
return var._get_replica(0)
74+
7175
def build_loss_fn(self, keras_model, compression_loss_fn):
7276
#filter_fn = self.make_filter_trainable_variables_fn()
7377
#trainable_variables = filter_fn(keras_model.trainable_variables)
74-
trainable_variables = [v for v in keras_model.layers[1].trainable_model.mirrored_variables if v.trainable]
78+
trainable_variables = [self.get_zero_replica_from_mirrored_var(v) for v in keras_model.layers[1].trainable_model.mirrored_variables if v.trainable]
7579

7680
def _total_loss_fn(labels, outputs):
7781
cls_loss = self._cls_loss_fn(outputs['cls_outputs'],

op_insertion.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,11 @@ def build(self, input_shape=None):
198198
concrete = tf_f.get_concrete_function(input_signature)
199199
structured_outputs = concrete.structured_outputs
200200
sorted_vars = get_sorted_on_captured_vars(concrete)
201-
model.mirrored_variables = model.orig_model.variables
201+
if isinstance(model.orig_model.variables[0], MirroredVariable):
202+
model.mirrored_variables = model.orig_model.variables
203+
else:
204+
# Case when model build before replica context
205+
model.mirrored_variables = self.create_mirrored_variables(sorted_vars)
202206

203207
else:
204208
concrete = make_new_func(model.graph_def,
@@ -209,6 +213,7 @@ def build(self, input_shape=None):
209213

210214
sorted_vars = get_sorted_on_captured_vars(concrete)
211215
model.mirrored_variables = self.create_mirrored_variables(sorted_vars)
216+
structured_outputs = None
212217

213218
if not self.initial_model_weights:
214219
self.initial_model_weights = self.get_numpy_weights_list(sorted_vars)
@@ -231,8 +236,8 @@ def build(self, input_shape=None):
231236
enable_quantization = True
232237
if enable_quantization:
233238
new_vars = []
234-
transformations = self.get_functional_retinanet_fq_placing_simular_to_nncf2_0(concrete.graph)
235-
#transformations = self.get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(concrete.graph)
239+
#transformations = self.get_functional_retinanet_fq_placing_simular_to_nncf2_0(concrete.graph)
240+
transformations = self.get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(concrete.graph)
236241
if training:
237242
#pass
238243
self.initialize_trainsformations(concrete, transformations)
@@ -350,6 +355,14 @@ def call(self, inputs, training=None):
350355
model_obj.fn_train.inputs,
351356
model_obj.output_tensor)
352357

358+
if model_obj.fn_train.structured_outputs is not None:
359+
# The order should be the same because
360+
# we use concrete.outputs when building new concrete function
361+
#outputs_list = nest.flatten(structured_outputs, expand_composites=True)
362+
fn_train._func_graph.structured_outputs = \
363+
nest.pack_sequence_as(model_obj.fn_train.structured_outputs,
364+
fn_train.outputs,
365+
expand_composites=True)
353366
return fn_train(inputs)
354367

355368
def initialize_trainsformations(self, concrete, trainsformations):
@@ -367,7 +380,7 @@ def initialize_trainsformations(self, concrete, trainsformations):
367380

368381
if self.calibration_dataset is None:
369382
return
370-
383+
return
371384
outputs = []
372385
activation_transformations = [t for t in trainsformations if t[1] != InsertionPoint.WEIGHTS]
373386
for op, _, _ in activation_transformations:

0 commit comments

Comments
 (0)