Skip to content

Commit 7c8489b

Browse files
Setup retinanet quantization
1 parent 34e2734 commit 7c8489b

File tree

2 files changed

+119
-29
lines changed

2 files changed

+119
-29
lines changed

examples/tensorflow/object_detection/main.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import tensorflow as tf
1919
import numpy as np
2020

21+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
22+
2123
from nncf.tensorflow import AdaptiveCompressionTrainingLoop
2224
from nncf.tensorflow import create_compressed_model
2325
from nncf.tensorflow.helpers.model_manager import TFOriginalModelManager
@@ -47,6 +49,20 @@
4749
from examples.tensorflow.object_detection.models.model_selector import get_predefined_config
4850
from examples.tensorflow.object_detection.models.model_selector import get_model_builder
4951

52+
def keras_model_to_frozen_graph(model):
    """Convert a Keras model into a frozen graph definition.

    One ``tf.TensorSpec`` is built per entry of ``model.inputs`` and the
    resulting list is used to trace a concrete function of the model. The
    variables of that function are then converted to constants.

    :param model: Keras model exposing ``inputs``; it is wrapped in
        ``tf.function`` for tracing.
    :return: graph definition of the frozen function, produced with
        ``add_shapes=True``.
    """
    specs = [tf.TensorSpec(tensor.shape, tensor.dtype) for tensor in model.inputs]
    traced_model = tf.function(model)
    # The spec list is passed as a single positional argument, exactly as
    # the model is expected to be called.
    concrete_function = traced_model.get_concrete_function(specs)
    frozen = convert_variables_to_constants_v2(concrete_function, lower_control_flow=False)
    graph_def = frozen.graph.as_graph_def(add_shapes=True)
    return graph_def
59+
60+
61+
def save_model_as_frozen_graph(model, save_path, as_text=False):
    """Freeze ``model`` and write the resulting graph to ``save_path``.

    :param model: Keras model handed to ``keras_model_to_frozen_graph``.
    :param save_path: destination path; it is split into a directory part
        and a file name before being passed to ``tf.io.write_graph``.
    :param as_text: forwarded unchanged to ``tf.io.write_graph``.
    """
    graph_def = keras_model_to_frozen_graph(model)
    target_dir, file_name = os.path.split(save_path)
    tf.io.write_graph(graph_def, target_dir, file_name, as_text=as_text)
65+
5066

5167
def get_argument_parser():
5268
parser = get_common_argument_parser(precision=False,
@@ -311,7 +327,7 @@ def model_eval_fn(model):
311327
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state)
312328
from op_insertion import NNCFWrapperCustom
313329
args = [model]
314-
inputs = tf.keras.layers.Input(shape=model.inputs[0].shape[1:])
330+
inputs = tf.keras.layers.Input(shape=model.inputs[0].shape[1:], name=model.inputs[0].name.split(':')[0])
315331
outputs = NNCFWrapperCustom(*args, caliblration_dataset=train_dataset,
316332
enable_mirrored_vars_split=False)(inputs)
317333
compress_model = tf.keras.Model(inputs=inputs, outputs=outputs)
@@ -381,9 +397,10 @@ def validate_fn(model, **kwargs):
381397
write_metrics(metric_result['AP'], config.metrics_dump)
382398

383399
if 'export' in config.mode:
384-
save_path, save_format = get_saving_parameters(config)
385-
compression_ctrl.export_model(save_path, save_format)
386-
logger.info("Saved to {}".format(save_path))
400+
save_model_as_frozen_graph(compress_model, config.to_frozen_graph)
401+
#save_path, save_format = get_saving_parameters(config)
402+
#compression_ctrl.export_model(save_path, save_format)
403+
#logger.info("Saved to {}".format(save_path))
387404

388405

389406
def export(config):

op_insertion.py

+98-25
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,28 @@ class InsertionPoint(object):
2424
AFTER_LAYER = 'after'
2525
BEFORE_LAYER = 'before'
2626

27+
@staticmethod
def from_str(input_str):
    """Translate a serialized target-type name into an insertion point.

    :param input_str: one of ``"AFTER_LAYER"``, ``"BEFORE_LAYER"`` or
        ``"OPERATION_WITH_WEIGHTS"``.
    :return: the matching ``InsertionPoint`` constant.
    :raises RuntimeError: if ``input_str`` is none of the known names.
    """
    # (serialized name, InsertionPoint attribute) pairs, checked in order.
    known_names = (
        ("AFTER_LAYER", "AFTER_LAYER"),
        ("BEFORE_LAYER", "BEFORE_LAYER"),
        ("OPERATION_WITH_WEIGHTS", "WEIGHTS"),
    )
    for serialized_name, attribute in known_names:
        if input_str == serialized_name:
            # Resolved lazily so only the matched constant is looked up.
            return getattr(InsertionPoint, attribute)

    raise RuntimeError('Wrong type of insertion point')
37+
2738

2839
class QuantizationSetup(object):
    """Mutable container for the settings of one quantizer insertion.

    Instances are plain attribute holders; callers read the fields and may
    overwrite them after construction.
    """

    def __init__(self,
                 signed=True,
                 narrow_range=False,
                 per_channel=False,
                 symmetric=True,
                 init_value=6):
        """Store the given settings as attributes of the same names.

        :param signed: whether the quantizer is signed.
        :param narrow_range: narrow-range flag of the quantizer.
        :param per_channel: per-channel flag of the quantizer.
        :param symmetric: symmetric flag; only stored by this class.
        :param init_value: initial value for the quantizer.
        """
        # Attributes are assigned in signature order.
        self.signed = signed
        self.narrow_range = narrow_range
        self.per_channel = per_channel
        self.symmetric = symmetric
        self.init_value = init_value
3350

3451

@@ -76,8 +93,46 @@ def __init__(self,
7693
# point_dict['_target_type'].pop('__objclass__')
7794
# res.append(point_dict)
7895
def get_functional_retinanet_fq_placing_simular_to_nncf2_0(self, g):
    """Build the quantizer placement for the functional RetinaNet graph.

    The layout is read from
    ``configs/quantization/retinanet_quantization_layout.json`` (a relative
    path, so it is resolved against the current working directory). Every
    layout entry is matched to exactly one operation of ``g`` using rules
    based on the entry's ``_layer_name``. Entries whose name starts with
    ``class`` or ``box`` are skipped.

    :param g: graph exposing ``get_operations()``; each returned operation
        must provide ``name`` and ``type``.
    :return: list of ``(operation, insertion_point, QuantizationSetup)``
        tuples in layout order.
    :raises RuntimeError: if an entry matches no naming rule, or does not
        resolve to exactly one operation.
    """
    path = 'configs/quantization/retinanet_quantization_layout.json'
    with open(path, 'r') as inp:
        layout = json.load(inp)

    # Query the graph once instead of once per layout entry.
    graph_ops = g.get_operations()

    transformations = []
    for op_layout in layout:
        layout_name = op_layout['_layer_name']
        scope_prefix = layout_name + '/'
        # Operations that live under this layer's name scope.
        layer_ops = [node for node in graph_ops if node.name.startswith(scope_prefix)]

        setup = QuantizationSetup(signed=op_layout['signedness_to_force'] in (True, None),
                                  narrow_range=op_layout['narrow_range'] or op_layout['half_range'],
                                  per_channel=op_layout['per_channel'])
        insertion_point = InsertionPoint.from_str(op_layout['_target_type']['_name_'])

        # NOTE: the order of the branches matters, several prefixes overlap
        # (e.g. 'post_hoc' also starts with 'p').
        if layout_name.startswith('input'):
            # The model input is taken to be the first operation of the graph.
            candidates = graph_ops[:1]
        elif layout_name.startswith('batch_normalization') or layout_name.endswith('bn'):
            candidates = [node for node in layer_ops if node.type == 'FusedBatchNormV3']
        elif layout_name.startswith(('l', 'post_hoc')):
            wanted_type = 'BiasAdd' if insertion_point == InsertionPoint.AFTER_LAYER else 'Conv2D'
            candidates = [node for node in layer_ops if node.type == wanted_type]
        elif layout_name.startswith(('class', 'box')):
            # Skip shared conv for now
            continue
        elif layout_name.startswith(('p', 'conv2d')):
            # 'post_hoc' names never reach this branch; they are handled above.
            candidates = [node for node in layer_ops if node.type == 'Conv2D']
        elif layout_name.startswith('up_sampling'):
            candidates = [node for node in layer_ops if node.type == 'ResizeNearestNeighbor']
        elif any(part.endswith(('Relu', 'add')) for part in layout_name.split('_')[-2:]):
            # Slicing (instead of indexing [-1] and [-2]) keeps names without
            # an underscore from failing with IndexError.
            candidates = layer_ops
            if 'Relu' in layout_name:
                setup.signed = False
        else:
            raise RuntimeError(f'You forgot about operation {layout_name}')

        # Explicit check instead of `assert`, which is removed under `-O`.
        if len(candidates) != 1:
            raise RuntimeError(f'Expected exactly one operation for layout entry {layout_name}, '
                               f'found {len(candidates)}')
        transformations.append((candidates[0], insertion_point, setup))

    return transformations
81136

82137
def get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(self, g):
83138
"""Hardcode fq placing for examples.classification.test_models.get_KerasLayer_model"""
@@ -108,9 +163,9 @@ def get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(self, g):
108163
#
109164
transformations = []
110165
# Transformations for blocks
111-
transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)) for op in depthwise_conv])
112-
transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)) for op in project_ops])
113-
transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)) for op in expand_ops])
166+
transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)) for op in depthwise_conv])
167+
transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)) for op in project_ops])
168+
transformations.extend([(op, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)) for op in expand_ops])
114169

115170
transformations.extend([(op, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False)) for op in depthwise_conv_relu])
116171
transformations.extend([(op, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=True)) for op in project_bn])
@@ -120,14 +175,14 @@ def get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(self, g):
120175
# FQ on inputs
121176
transformations.append((first_conv, InsertionPoint.BEFORE_LAYER, QuantizationSetup(signed=True)))
122177
# FQ on first conv weights
123-
transformations.append((first_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)))
178+
transformations.append((first_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)))
124179
# FQ after first conv relu
125180
transformations.append((first_conv_relu, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False)))
126181
# Transformation for net tail
127-
transformations.append((last_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)))
182+
transformations.append((last_conv, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)))
128183
transformations.append((last_conv_relu, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False)))
129184
transformations.append((avg_pool, InsertionPoint.AFTER_LAYER, QuantizationSetup(signed=False)))
130-
transformations.append((prediction_mul, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True)))
185+
transformations.append((prediction_mul, InsertionPoint.WEIGHTS, QuantizationSetup(signed=True, narrow_range=False)))
131186
assert len(transformations) == 117
132187

133188
return transformations
@@ -136,8 +191,12 @@ def build(self, input_shape=None):
136191
for training, model in zip([True, False], [self.trainable_model, self.eval_model]):
137192
if self.model_type != ModelType.KerasLayer:
138193
tf_f = tf.function(lambda x: model.orig_model.call(x, training=training))
139-
concrete = tf_f.get_concrete_function(*[tf.TensorSpec(input_shape, tf.float32)])
194+
input_signature = []
195+
for item in model.orig_model.inputs:
196+
input_signature.append(tf.TensorSpec(item.shape, item.dtype))
140197

198+
concrete = tf_f.get_concrete_function(input_signature)
199+
structured_outputs = concrete.structured_outputs
141200
sorted_vars = get_sorted_on_captured_vars(concrete)
142201
model.mirrored_variables = model.orig_model.variables
143202

@@ -150,10 +209,6 @@ def build(self, input_shape=None):
150209

151210
sorted_vars = get_sorted_on_captured_vars(concrete)
152211
model.mirrored_variables = self.create_mirrored_variables(sorted_vars)
153-
###
154-
### Generated weights preprocessing
155-
###
156-
### Insert compression operation
157212

158213
if not self.initial_model_weights:
159214
self.initial_model_weights = self.get_numpy_weights_list(sorted_vars)
@@ -173,31 +228,32 @@ def build(self, input_shape=None):
173228
# Add new op to layer
174229
if not self.ops_vars_created:
175230
self.op_vars = []
176-
enable_quantization = False
231+
enable_quantization = True
177232
if enable_quantization:
178233
new_vars = []
179234
transformations = self.get_functional_retinanet_fq_placing_simular_to_nncf2_0(concrete.graph)
180235
#transformations = self.get_keras_layer_mobilenet_v2_fq_placing_simular_to_nncf2_0(concrete.graph)
181236
if training:
182-
pass
183-
#self.initialize_trainsformations(concrete, transformations)
237+
#pass
238+
self.initialize_trainsformations(concrete, transformations)
184239

185240
with concrete.graph.as_default() as g:
186241
# Insert given transformations
187242
for op, insertion_point, setup in transformations:
188243
def fq_creation(input_tensor, name):
189244
return create_fq_with_weights(input_tensor=input_tensor,
245+
per_channel=setup.per_channel,
190246
name=name,
191247
signed=setup.signed,
192248
init_value=setup.init_value,
193249
narrow_range=setup.narrow_range)
194250

195251
if insertion_point == InsertionPoint.AFTER_LAYER:
196-
new_vars.append(insert_op_after(g, op, 0, fq_creation, op.name))
252+
new_vars.append(insert_op_after(g, op, 0, fq_creation, f'{op.name}_after_layer'))
197253
elif insertion_point == InsertionPoint.BEFORE_LAYER:
198254
new_vars.append(insert_op_before(g, op, 0, fq_creation, f'{op.name}_before_layer'))
199255
elif insertion_point == InsertionPoint.WEIGHTS:
200-
new_vars.append(insert_op_before(g, op, 1, fq_creation, op.name))
256+
new_vars.append(insert_op_before(g, op, 1, fq_creation, f'{op.name}_weights'))
201257
else:
202258
raise RuntimeError('Wrong insertion point in quantization algo')
203259

@@ -220,12 +276,22 @@ def fq_creation(input_tensor, name):
220276
for new_var, (_, placeholder) in zip(new_ops_vars, old_captures[-len(self.op_vars):]):
221277
new_captures.append((new_var.handle, placeholder))
222278
new_variables = [v for v in concrete.variables] + new_ops_vars
279+
if len(new_variables) != len(new_captures):
280+
raise RuntimeError('Len of the new vars should be the same as len'
281+
' of new captures (possible some compression weights missing)')
282+
223283
concrete = make_new_func(concrete.graph.as_graph_def(),
224284
new_captures,
225285
new_variables,
226286
concrete.inputs,
227287
concrete.outputs)
228288

289+
if structured_outputs is not None:
290+
# The order should be the same because
291+
# we use concrete.outputs when building new concrete function
292+
#outputs_list = nest.flatten(structured_outputs, expand_composites=True)
293+
concrete._func_graph.structured_outputs = \
294+
nest.pack_sequence_as(structured_outputs, concrete.outputs, expand_composites=True)
229295
model.output_tensor = concrete.graph.outputs
230296
model.fn_train = concrete
231297

@@ -297,14 +363,14 @@ def initialize_trainsformations(self, concrete, trainsformations):
297363
min_val, max_val = self.get_min_max_op_weights(concrete.graph, op, concrete.inputs,
298364
self.initial_model_weights)
299365
setup.init_value = max(abs(min_val), abs(max_val))
300-
setup.narrow_range = True
366+
#setup.narrow_range = True
301367

302368
if self.calibration_dataset is None:
303369
return
304370

305371
outputs = []
306372
activation_transformations = [t for t in trainsformations if t[1] != InsertionPoint.WEIGHTS]
307-
for op, insertion_point, setup in activation_transformations:
373+
for op, _, _ in activation_transformations:
308374
outputs.append(op.outputs[0])
309375

310376
# Create concrete function with outputs from each activation
@@ -324,7 +390,7 @@ def initialize_trainsformations(self, concrete, trainsformations):
324390
# Update quantization setup
325391
for i, (_, _, setup) in enumerate(activation_transformations):
326392
setup.init_value = max(abs(np.mean(mins[i])), abs(np.mean(maxs[i])))
327-
setup.narrow_range = False
393+
#setup.narrow_range = False
328394

329395
def get_min_max_op_weights(self, graph, op, placeholders, np_vars):
330396
try:
@@ -480,9 +546,11 @@ def insert_op_after(graph, target_parent, output_index, node_creation_fn, name):
480546
return node_weights
481547

482548

483-
def create_fq_with_weights(input_tensor, name, signed, init_value, narrow_range):
549+
def create_fq_with_weights(input_tensor, per_channel, name, signed, init_value, narrow_range):
484550
"""Should be called in graph context"""
485551
with variable_scope.variable_scope('new_node'):
552+
# Should check whether the variable already exists
553+
# and raise an error if it does
486554
scale = variable_scope.get_variable(
487555
f'scale_{name}',
488556
shape=(),
@@ -491,8 +559,13 @@ def create_fq_with_weights(input_tensor, name, signed, init_value, narrow_range)
491559
trainable=True)
492560

493561
min = -scale if signed else 0.
494-
output_tensor = tf.quantization.fake_quant_with_min_max_vars(input_tensor, min, scale,
495-
narrow_range=narrow_range)
562+
if False:#per_channel:
563+
# Per channel not implemented yet
564+
output_tensor = tf.quantization.fake_quant_with_min_max_vars_per_channel(input_tensor, min, scale,
565+
narrow_range=narrow_range)
566+
else:
567+
output_tensor = tf.quantization.fake_quant_with_min_max_vars(input_tensor, min, scale,
568+
narrow_range=narrow_range)
496569
return output_tensor, scale
497570

498571

0 commit comments

Comments
 (0)