Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF1 model quantization #9

Open
wants to merge 20 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions contrib/input_to_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a Tensorflow model graph with quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections


# Skip all operations that are backprop related or export summaries.
SKIPPED_PREFIXES = ()


class InputToOps(object):
"""Holds a mapping from tensor's name to ops that take it as input."""

def __init__(self, graph):
"""Initializes mapping from tensor's name to ops that take it.

Helps find edges between ops faster and avoids iterating over the whole
graph. The mapping is of type Dict[str, Set[tf.Operation]].

Note: while inserting operations into the graph, we do not update the
mapping, assuming that insertion points in the graph are never adjacent.
With that restriction, an out of date mapping still works fine.

Args:
graph: Graph to process.
"""
self.mapping = collections.defaultdict(set)
for op in (op for op in graph.get_operations()):
if op.name.startswith(SKIPPED_PREFIXES):
continue
for op_input in op.inputs:
self.mapping[op_input.ref()].add(op)

def ConsumerOperations(self, producer_op):
"""Looks through outputs of producer_op, finds ops that take them as input.

Args:
producer_op: Operation containing outputs to process.

Returns:
A Set[Operation] containing all operations taking input from producer_op
outputs.
"""
result = set()
for inp in producer_op.outputs:
result.update(self.mapping[inp.ref()])
return result
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

"batch_size": 256,
"epochs": 15,
"num_classes": 1001,
"dataset_preprocessing_preset": "imagenet2012_slim",

"optimizer": {
"type": "Adam",
Expand All @@ -18,20 +20,6 @@
},

"dataset": "imagenet2012",
"dataset_type": "tfds",
"dataset_type": "tfrecords"

"compression": {
"algorithm": "quantization",
"initializer": {
"batchnorm_adaptation": {
"num_bn_adaptation_samples": 2048
}
},
"weights": {
"per_channel": false
},
"activations": {
"per_channel": false
}
}
}
144 changes: 103 additions & 41 deletions examples/tensorflow/classification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,29 @@
limitations under the License.
"""

import os
import sys
import os.path as osp
from pathlib import Path

import tensorflow as tf
import tensorflow_addons as tfa

from pathlib import Path
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

from nncf.config.utils import is_accuracy_aware_training
from nncf.tensorflow.helpers.model_creation import create_compressed_model
from nncf.tensorflow import create_compression_callbacks
from nncf.tensorflow.helpers.model_manager import TFOriginalModelManager
from nncf.tensorflow.initialization import register_default_init_args
from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.utils.state import TFCompressionStateLoader

from examples.tensorflow.classification.datasets.builder import DatasetBuilder
from examples.tensorflow.common.argparser import get_common_argument_parser
from examples.tensorflow.common.callbacks import get_callbacks
from examples.tensorflow.common.callbacks import get_progress_bar
from examples.tensorflow.common.distributed import get_distribution_strategy
from examples.tensorflow.common.logger import logger
from examples.tensorflow.common.model_loader import get_model
from examples.tensorflow.common.model_loader import get_model as get_model_old
from examples.tensorflow.common.optimizer import build_optimizer
from examples.tensorflow.common.sample_config import create_sample_config
from examples.tensorflow.common.scheduler import build_scheduler
Expand All @@ -43,6 +44,37 @@
from examples.tensorflow.common.utils import serialize_config
from examples.tensorflow.common.utils import serialize_cli_args
from examples.tensorflow.common.utils import write_metrics
from examples.tensorflow.classification.test_models import get_KerasLayer_model
from examples.tensorflow.classification.test_models import get_model
from examples.tensorflow.classification.test_models import ModelType

# KerasLayer with NNCFWrapper 1 epoch
# runs/MobileNetV2_imagenet2012/2021-07-21__14-22-44
# Keras Layer pure 1 epoch
# runs/MobileNetV2_imagenet2012/2021-07-21__14-53-04


def keras_model_to_frozen_graph(model):
input_signature = []
for item in model.inputs:
input_signature.append(tf.TensorSpec(item.shape, item.dtype))
concrete_function = tf.function(model).get_concrete_function(input_signature)
frozen_func = convert_variables_to_constants_v2(concrete_function, lower_control_flow=False)
return frozen_func.graph.as_graph_def(add_shapes=True)


def save_model_as_frozen_graph(model, save_path, as_text=False):
frozen_graph = keras_model_to_frozen_graph(model)
save_dir, name = os.path.split(save_path)
tf.io.write_graph(frozen_graph, save_dir, name, as_text=as_text)


class DummyContextManager:
def __enter__(self):
pass

def __exit__(self, *args):
pass


def get_argument_parser():
Expand All @@ -64,6 +96,11 @@ def get_argument_parser():
help="Use pretrained models from the tf.keras.applications",
action="store_true",
)
parser.add_argument(
"--model_type",
choices=[ModelType.KerasLayer, ModelType.FuncModel, ModelType.SubClassModel],
default=ModelType.KerasLayer,
help="Type of mobilenetV2 model which should be quantized.")
return parser


Expand Down Expand Up @@ -152,12 +189,18 @@ def run(config):
if config.metrics_dump is not None:
write_metrics(0, config.metrics_dump)

model_fn, model_params = get_model(config.model,
model_fn, model_params = get_model_old(config.model,
input_shape=config.get('input_info', {}).get('sample_size', None),
num_classes=config.get('num_classes', get_num_classes(config.dataset)),
pretrained=config.get('pretrained', False),
weights=config.get('weights', None))

if config.model_type == ModelType.KerasLayer:
#args = None
args = get_KerasLayer_model()
else:
args = None

builders = get_dataset_builders(config, strategy.num_replicas_in_sync)
datasets = [builder.build() for builder in builders]

Expand Down Expand Up @@ -188,10 +231,20 @@ def run(config):
if resume_training:
compression_state = load_compression_state(config.ckpt_path)

with TFOriginalModelManager(model_fn, **model_params) as model:
with DummyContextManager():
with strategy.scope():
compression_ctrl, compress_model = create_compressed_model(model, nncf_config, compression_state)
compression_callbacks = create_compression_callbacks(compression_ctrl, log_dir=config.log_dir)
if not args:
args = get_model(config.model_type)

from op_insertion import NNCFWrapperCustom
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(224, 224, 3)),
NNCFWrapperCustom(*args, caliblration_dataset=train_dataset),
#args[0]['layer'],
tf.keras.layers.Activation('softmax')
])
#compression_ctrl, compress_model = create_compressed_model(model, nncf_config, compression_state)
compress_model = model

scheduler = build_scheduler(
config=config,
Expand All @@ -202,13 +255,13 @@ def run(config):

loss_obj = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)

compress_model.add_loss(compression_ctrl.loss)
#compress_model.add_loss(compression_ctrl.loss)

metrics = [
tf.keras.metrics.CategoricalAccuracy(name='acc@1'),
tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc@5'),
tfa.metrics.MeanMetricWrapper(loss_obj, name='ce_loss'),
tfa.metrics.MeanMetricWrapper(compression_ctrl.loss, name='cr_loss')
#tfa.metrics.MeanMetricWrapper(compression_ctrl.loss, name='cr_loss')
]

compress_model.compile(optimizer=optimizer,
Expand All @@ -218,14 +271,17 @@ def run(config):

compress_model.summary()

checkpoint = tf.train.Checkpoint(model=compress_model,
compression_state=TFCompressionState(compression_ctrl))
checkpoint = tf.train.Checkpoint(model=compress_model)

initial_epoch = 0
if resume_training:
initial_epoch = resume_from_checkpoint(checkpoint=checkpoint,
ckpt_path=config.ckpt_path,
steps_per_epoch=train_steps)
weights_path = config.get('weights', None)
if weights_path:
compress_model.load_weights(weights_path)
logger.info(f'Weights from {weights_path} were loaded successfully')

callbacks = get_callbacks(
include_tensorboard=True,
Expand All @@ -238,42 +294,46 @@ def run(config):

callbacks.append(get_progress_bar(
stateful_metrics=['loss'] + [metric.name for metric in metrics]))
callbacks.extend(compression_callbacks)
#callbacks.extend(compression_callbacks)

validation_kwargs = {
'validation_data': validation_dataset,
'validation_steps': validation_steps,
'validation_freq': config.test_every_n_epochs,
}

# BN INITIALIZATION
# Set trainable graph for eval
enable_bn = False
if not resume_training and enable_bn:
print(25*'*')
print('Start BN adaptiation')
print(25*'*')
compress_model.layers[0].training_forced = True
# Update BN statistics
compress_model.evaluate(train_dataset,
steps=1000,
callbacks=[get_progress_bar(
stateful_metrics=['loss'] + [metric.name for metric in metrics])],
verbose=1)
# Reset model
compress_model.layers[0].training_forced = None
compress_model.compile(optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
run_eagerly=config.get('eager_mode', False))
###
if 'train' in config.mode:
if is_accuracy_aware_training(config):
logger.info('starting an accuracy-aware training loop...')
result_dict_to_val_metric_fn = lambda results: 100 * results['acc@1']
compress_model.accuracy_aware_fit(train_dataset,
compression_ctrl,
nncf_config=config.nncf_config,
callbacks=callbacks,
initial_epoch=initial_epoch,
steps_per_epoch=train_steps,
tensorboard_writer=config.tb,
log_dir=config.log_dir,
uncompressed_model_accuracy=uncompressed_model_accuracy,
result_dict_to_val_metric_fn=result_dict_to_val_metric_fn,
**validation_kwargs)
else:
logger.info('training...')
compress_model.fit(
train_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
**validation_kwargs)
logger.info('training...')
compress_model.fit(
train_dataset,
epochs=train_epochs,
steps_per_epoch=train_steps,
initial_epoch=initial_epoch,
callbacks=callbacks,
**validation_kwargs)

logger.info('evaluation...')
statistics = compression_ctrl.statistics()
logger.info(statistics.to_str())
results = compress_model.evaluate(
validation_dataset,
steps=validation_steps,
Expand All @@ -285,9 +345,10 @@ def run(config):
write_metrics(results[1], config.metrics_dump)

if 'export' in config.mode:
save_path, save_format = get_saving_parameters(config)
compression_ctrl.export_model(save_path, save_format)
logger.info('Saved to {}'.format(save_path))
save_model_as_frozen_graph(compress_model, config.to_frozen_graph)
#save_path, save_format = get_saving_parameters(config)
#compression_ctrl.export_model(save_path, save_format)
#logger.info('Saved to {}'.format(save_path))


def export(config):
Expand Down Expand Up @@ -329,6 +390,7 @@ def export(config):
def main(argv):
parser = get_argument_parser()
config = get_config_from_argv(argv, parser)
#config['eager_mode'] = True
print_args(config)

serialize_config(config.nncf_config, config.log_dir)
Expand Down
Loading