
Commit eab6b46

Update test_examples (#3195)

### Changes

Add quantization_aware_training_tensorflow_mobilenet_v2 to the test scope.
Parent: f574a1f

4 files changed (+38 -4 lines)

examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py (+6 -3)

@@ -113,7 +113,7 @@ def preprocess_for_train(image, label):
 
 
 train_dataset = tfds.load("imagenette/320px-v2", split="train", shuffle_files=True, as_supervised=True)
-train_dataset = train_dataset.map(preprocess_for_train).shuffle(1024).batch(128)
+train_dataset = train_dataset.map(preprocess_for_train).batch(64)
 
 val_dataset = tfds.load("imagenette/320px-v2", split="validation", shuffle_files=False, as_supervised=True)
 val_dataset = val_dataset.map(preprocess_for_eval).batch(128)

@@ -150,12 +150,15 @@ def transform_fn(data_item):
 tf_quantized_model = nncf.quantize(tf_model, calibration_dataset)
 
 tf_quantized_model.compile(
-    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
+    optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
     loss=tf.keras.losses.CategoricalCrossentropy(),
     metrics=[tf.keras.metrics.CategoricalAccuracy()],
 )
 
-tf_quantized_model.fit(train_dataset, epochs=3, verbose=1)
+# To minimize the example's runtime, we train for only 1 epoch. This is sufficient to demonstrate
+# that the quantized model produced by QAT is more accurate than the one produced by PTQ.
+# However, training for more than 1 epoch would further improve the quantized model's accuracy.
+tf_quantized_model.fit(train_dataset, epochs=1, verbose=1)
 
 # Removes auxiliary layers and operations added during the quantization process,
 # resulting in a clean, fully quantized model ready for deployment.
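
The new comment hinges on comparing the model right after post-training quantization with the same model after a single epoch of fine-tuning. A minimal sketch of that comparison, assuming tf_quantized_model, train_dataset, and val_dataset are set up exactly as in main.py; the ptq_top1 / qat_top1 names are illustrative and not part of the example:

# Sketch only: assumes tf_quantized_model was produced by nncf.quantize() and
# compiled with the CategoricalAccuracy metric as shown in the hunk above.
_, ptq_top1 = tf_quantized_model.evaluate(val_dataset, verbose=0)  # PTQ accuracy, before fit()

tf_quantized_model.fit(train_dataset, epochs=1, verbose=1)  # one epoch of QAT, as in the example

_, qat_top1 = tf_quantized_model.evaluate(val_dataset, verbose=0)  # accuracy after QAT
print(f"PTQ top-1: {ptq_top1:.4f}  QAT top-1: {qat_top1:.4f}")  # expected: qat_top1 > ptq_top1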

tests/cross_fw/examples/.test_durations (+2 -1)

@@ -14,5 +14,6 @@
     "tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_anomalib]": 478.797,
     "tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_resnet18]": 1251.144,
     "tests/cross_fw/examples/test_examples.py::test_examples[post_training_quantization_torch_fx_resnet18]": 412.243,
-    "tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69
+    "tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69,
+    "tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_tensorflow_mobilenet_v2]": 1500.00
 }
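
The durations file is plain JSON keyed by pytest node ID, so an addition like the one above can be sanity-checked with a few lines of Python; this snippet is illustrative only and not part of the repository:

import json

# Load the per-test duration estimates (in seconds) recorded for the example tests.
with open("tests/cross_fw/examples/.test_durations") as f:
    durations = json.load(f)

key = (
    "tests/cross_fw/examples/test_examples.py::"
    "test_examples[quantization_aware_training_tensorflow_mobilenet_v2]"
)
assert key in durations and durations[key] > 0  # new entry present with a positive estimate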

tests/cross_fw/examples/example_scope.json (+21)

@@ -273,5 +273,26 @@
                 "Tokyo."
             ]
         }
+    },
+    "quantization_aware_training_tensorflow_mobilenet_v2": {
+        "backend": "tf",
+        "requirements": "examples/quantization_aware_training/tensorflow/mobilenet_v2/requirements.txt",
+        "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
+        "accuracy_tolerance": 0.003,
+        "accuracy_metrics": {
+            "fp32_top1": 0.987770676612854,
+            "int8_top1": 0.9737579822540283,
+            "accuracy_drop": 0.014012694358825684
+        },
+        "performance_metrics": {
+            "fp32_fps": 1703.04,
+            "int8_fps": 5796.3,
+            "performance_speed_up": 3.403501972942503
+        },
+        "model_size_metrics": {
+            "fp32_model_size": 8.596238136291504,
+            "int8_model_size": 2.69466495513916,
+            "model_compression_rate": 3.1900953474371994
+        }
     }
 }
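
Each entry in example_scope.json records reference metrics together with an accuracy_tolerance that measured results are expected to stay within. A rough sketch of such a tolerance check, with hypothetical variable names and measured values (the actual comparison is implemented in tests/cross_fw/examples/test_examples.py and may differ):

# Illustrative only: how a reference entry like the one above could be compared
# against freshly measured accuracy metrics. Names and measured values are assumptions.
reference = {
    "accuracy_tolerance": 0.003,
    "accuracy_metrics": {
        "fp32_top1": 0.987770676612854,
        "int8_top1": 0.9737579822540283,
    },
}
measured = {"fp32_top1": 0.9875, "int8_top1": 0.9741}  # hypothetical run results

tolerance = reference["accuracy_tolerance"]
for name, expected in reference["accuracy_metrics"].items():
    delta = abs(measured[name] - expected)
    assert delta <= tolerance, f"{name} deviates from the reference by {delta} (> {tolerance})"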

tests/cross_fw/examples/run_example.py (+9)

@@ -279,6 +279,15 @@ def quantization_aware_training_torch_anomalib(data: Union[str, None]):
     }
 
 
+def quantization_aware_training_tensorflow_mobilenet_v2() -> Dict[str, float]:
+    import tensorflow_datasets as tfds
+
+    tfds.display_progress_bar(enable=False)
+
+    example_root = str(PROJECT_ROOT / "examples" / "quantization_aware_training" / "tensorflow" / "mobilenet_v2")
+    return post_training_quantization_mobilenet_v2(example_root)
+
+
 def main(argv):
     parser = ArgumentParser()
     parser.add_argument("--name", help="Example name", required=True)
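
With this helper registered, the new example can be exercised through the runner by name, e.g. `python tests/cross_fw/examples/run_example.py --name quantization_aware_training_tensorflow_mobilenet_v2`; any further arguments the parser defines outside the shown context would need to be supplied as well.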
