@@ -198,7 +198,11 @@ def build(self, input_shape=None):
198
198
concrete = tf_f .get_concrete_function (input_signature )
199
199
structured_outputs = concrete .structured_outputs
200
200
sorted_vars = get_sorted_on_captured_vars (concrete )
201
- model .mirrored_variables = model .orig_model .variables
201
+ if isinstance (model .orig_model .variables [0 ], MirroredVariable ):
202
+ model .mirrored_variables = model .orig_model .variables
203
+ else :
204
+ # Case when model build before replica context
205
+ model .mirrored_variables = self .create_mirrored_variables (sorted_vars )
202
206
203
207
else :
204
208
concrete = make_new_func (model .graph_def ,
@@ -209,6 +213,7 @@ def build(self, input_shape=None):
209
213
210
214
sorted_vars = get_sorted_on_captured_vars (concrete )
211
215
model .mirrored_variables = self .create_mirrored_variables (sorted_vars )
216
+ structured_outputs = None
212
217
213
218
if not self .initial_model_weights :
214
219
self .initial_model_weights = self .get_numpy_weights_list (sorted_vars )
@@ -231,8 +236,8 @@ def build(self, input_shape=None):
231
236
enable_quantization = True
232
237
if enable_quantization :
233
238
new_vars = []
234
- transformations = self .get_functional_retinanet_fq_placing_simular_to_nncf2_0 (concrete .graph )
235
- # transformations = self.get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(concrete.graph)
239
+ # transformations = self.get_functional_retinanet_fq_placing_simular_to_nncf2_0(concrete.graph)
240
+ transformations = self .get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0 (concrete .graph )
236
241
if training :
237
242
#pass
238
243
self .initialize_trainsformations (concrete , transformations )
@@ -350,6 +355,14 @@ def call(self, inputs, training=None):
350
355
model_obj .fn_train .inputs ,
351
356
model_obj .output_tensor )
352
357
358
+ if model_obj .fn_train .structured_outputs is not None :
359
+ # The order should be the same because
360
+ # we use concrete.outputs when building new concrete function
361
+ #outputs_list = nest.flatten(structured_outputs, expand_composites=True)
362
+ fn_train ._func_graph .structured_outputs = \
363
+ nest .pack_sequence_as (model_obj .fn_train .structured_outputs ,
364
+ fn_train .outputs ,
365
+ expand_composites = True )
353
366
return fn_train (inputs )
354
367
355
368
def initialize_trainsformations (self , concrete , trainsformations ):
@@ -367,7 +380,7 @@ def initialize_trainsformations(self, concrete, trainsformations):
367
380
368
381
if self .calibration_dataset is None :
369
382
return
370
-
383
+ return
371
384
outputs = []
372
385
activation_transformations = [t for t in trainsformations if t [1 ] != InsertionPoint .WEIGHTS ]
373
386
for op , _ , _ in activation_transformations :
0 commit comments