Skip to content

Commit 7261204

Browse files
305 python api (openvinotoolkit#938)
1 parent bc033ea commit 7261204

File tree

1 file changed

+47
-63
lines changed

1 file changed

+47
-63
lines changed

notebooks/305-tensorflow-quantization-aware-training/305-tensorflow-quantization-aware-training.ipynb

+47-63
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
"from nncf.tensorflow.helpers.model_creation import create_compressed_model\n",
5050
"from nncf.tensorflow.initialization import register_default_init_args\n",
5151
"from nncf.common.logging.logger import set_log_level\n",
52+
"from openvino.runtime import serialize\n",
53+
"from openvino.tools import mo\n",
5254
"\n",
5355
"set_log_level(logging.ERROR)\n",
5456
"\n",
@@ -60,10 +62,8 @@
6062
"BASE_MODEL_NAME = \"ResNet-18\"\n",
6163
"\n",
6264
"fp32_h5_path = Path(MODEL_DIR / (BASE_MODEL_NAME + \"_fp32\")).with_suffix(\".h5\")\n",
63-
"fp32_sm_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + \"_fp32\"))\n",
6465
"fp32_ir_path = Path(OUTPUT_DIR / \"saved_model\").with_suffix(\".xml\")\n",
6566
"int8_pb_path = Path(OUTPUT_DIR / (BASE_MODEL_NAME + \"_int8\")).with_suffix(\".pb\")\n",
66-
"int8_pb_name = Path(BASE_MODEL_NAME + \"_int8\").with_suffix(\".pb\")\n",
6767
"int8_ir_path = int8_pb_path.with_suffix(\".xml\")\n",
6868
"\n",
6969
"BATCH_SIZE = 128\n",
@@ -222,7 +222,7 @@
222222
"outputs": [],
223223
"source": [
224224
"IMG_SHAPE = IMG_SIZE + (3,)\n",
225-
"model = ResNet18(input_shape=IMG_SHAPE)"
225+
"fp32_model = ResNet18(input_shape=IMG_SHAPE)"
226226
]
227227
},
228228
{
@@ -245,37 +245,22 @@
245245
"outputs": [],
246246
"source": [
247247
"# Load the floating-point weights.\n",
248-
"model.load_weights(fp32_h5_path)\n",
248+
"fp32_model.load_weights(fp32_h5_path)\n",
249249
"\n",
250250
"# Compile the floating-point model.\n",
251-
"model.compile(loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),\n",
252-
" metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc@1')])\n",
251+
"fp32_model.compile(\n",
252+
" loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),\n",
253+
" metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc@1')]\n",
254+
")\n",
253255
"\n",
254256
"# Validate the floating-point model.\n",
255-
"test_loss, acc_fp32 = model.evaluate(validation_dataset,\n",
256-
" callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1']))\n",
257+
"test_loss, acc_fp32 = fp32_model.evaluate(\n",
258+
" validation_dataset,\n",
259+
" callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1'])\n",
260+
")\n",
257261
"print(f\"\\nAccuracy of FP32 model: {acc_fp32:.3f}\")"
258262
]
259263
},
260-
{
261-
"cell_type": "markdown",
262-
"id": "b80f67d6",
263-
"metadata": {},
264-
"source": [
265-
"Save the floating-point model to the saved model, which will be later used for conversion to OpenVINO IR and further performance measurement."
266-
]
267-
},
268-
{
269-
"cell_type": "code",
270-
"execution_count": null,
271-
"id": "450cbcb2",
272-
"metadata": {},
273-
"outputs": [],
274-
"source": [
275-
"model.save(fp32_sm_path)\n",
276-
"print(f'Absolute path where the model is saved:\\n {fp32_sm_path.resolve()}')"
277-
]
278-
},
279264
{
280265
"cell_type": "markdown",
281266
"id": "13b81167",
@@ -346,7 +331,7 @@
346331
"metadata": {},
347332
"outputs": [],
348333
"source": [
349-
"compression_ctrl, model = create_compressed_model(model, nncf_config)"
334+
"compression_ctrl, int8_model = create_compressed_model(fp32_model, nncf_config)"
350335
]
351336
},
352337
{
@@ -365,13 +350,17 @@
365350
"outputs": [],
366351
"source": [
367352
"# Compile the INT8 model.\n",
368-
"model.compile(optimizer=tf.keras.optimizers.Adam(lr=LR),\n",
369-
" loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),\n",
370-
" metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc@1')])\n",
353+
"int8_model.compile(\n",
354+
" optimizer=tf.keras.optimizers.Adam(lr=LR),\n",
355+
" loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),\n",
356+
" metrics=[tf.keras.metrics.CategoricalAccuracy(name='acc@1')]\n",
357+
")\n",
371358
"\n",
372359
"# Validate the INT8 model.\n",
373-
"test_loss, test_acc = model.evaluate(validation_dataset,\n",
374-
" callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1']))\n",
360+
"test_loss, test_acc = int8_model.evaluate(\n",
361+
" validation_dataset,\n",
362+
" callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1'])\n",
363+
")\n",
375364
"print(f\"\\nAccuracy of INT8 model after initialization: {test_acc:.3f}\")"
376365
]
377366
},
@@ -393,53 +382,38 @@
393382
"scrolled": true,
394383
"tags": [],
395384
"test_replace": {
396-
"fit(train_dataset,": "fit(validation_dataset,"
385+
"train_dataset,": "validation_dataset,"
397386
}
398387
},
399388
"outputs": [],
400389
"source": [
401390
"# Train the INT8 model.\n",
402-
"model.fit(train_dataset,\n",
403-
" epochs=2)\n",
391+
"int8_model.fit(\n",
392+
" train_dataset,\n",
393+
" epochs=2\n",
394+
")\n",
404395
"\n",
405396
"# Validate the INT8 model.\n",
406-
"test_loss, acc_int8 = model.evaluate(validation_dataset,\n",
407-
" callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1']))\n",
397+
"test_loss, acc_int8 = int8_model.evaluate(\n",
398+
" validation_dataset,\n",
399+
" callbacks=tf.keras.callbacks.ProgbarLogger(stateful_metrics=['acc@1'])\n",
400+
")\n",
408401
"print(f\"\\nAccuracy of INT8 model after fine-tuning: {acc_int8:.3f}\")\n",
409402
"print(f\"\\nAccuracy drop of tuned INT8 model over pre-trained FP32 model: {acc_fp32 - acc_int8:.3f}\")"
410403
]
411404
},
412-
{
413-
"cell_type": "markdown",
414-
"id": "7af453ef",
415-
"metadata": {},
416-
"source": [
417-
"Save the `INT8` model to the frozen graph (saved model does not work with quantized model for now). Frozen graph will be later used for conversion to OpenVINO IR and further performance measurement."
418-
]
419-
},
420-
{
421-
"cell_type": "code",
422-
"execution_count": null,
423-
"id": "6b208b6c",
424-
"metadata": {},
425-
"outputs": [],
426-
"source": [
427-
"compression_ctrl.export_model(int8_pb_path, 'frozen_graph')\n",
428-
"print(f'Absolute path where the int8 model is saved:\\n {int8_pb_path.resolve()}')"
429-
]
430-
},
431405
{
432406
"cell_type": "markdown",
433407
"id": "1248a563",
434408
"metadata": {},
435409
"source": [
436-
"## Export Frozen Graph Models to OpenVINO Intermediate Representation (IR)\n",
410+
"## Export Models to OpenVINO Intermediate Representation (IR)\n",
437411
"\n",
438-
"Use Model Optimizer to convert the Saved Model and Frozen Graph models to OpenVINO IR. The models are saved to the current directory.\n",
412+
"Use the Model Optimizer Python API to convert the models to OpenVINO IR.\n",
439413
"\n",
440-
"For more information about Model Optimizer, see the [Model Optimizer Developer Guide](https://docs.openvino.ai/latest/openvino_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html).\n",
414+
"For more information about the Model Optimizer Python API, see the [Model Optimizer Python API documentation](https://docs.openvino.ai/latest/openvino_docs_MO_DG_Python_API.html).\n",
441415
"\n",
442-
"Executing this command may take a while. There may be some errors or warnings in the output. When Model Optimization successfully exports to OpenVINO IR, the last lines of the output will include: `[ SUCCESS ] Generated IR version 11 model`"
416+
"Executing this command may take a while."
443417
]
444418
},
445419
{
@@ -449,7 +423,10 @@
449423
"metadata": {},
450424
"outputs": [],
451425
"source": [
452-
"!mo --input_shape=\"[1,64,64,3]\" --input=data --saved_model_dir=$fp32_sm_path --output_dir=$OUTPUT_DIR"
426+
"model_ir_fp32 = mo.convert_model(\n",
427+
" fp32_model,\n",
428+
" input_shape=[1, 64, 64, 3],\n",
429+
")"
453430
]
454431
},
455432
{
@@ -459,7 +436,10 @@
459436
"metadata": {},
460437
"outputs": [],
461438
"source": [
462-
"!mo --input_shape=\"[1,64,64,3]\" --input=Placeholder --input_model=$int8_pb_path --output_dir=$OUTPUT_DIR"
439+
"model_ir_int8 = mo.convert_model(\n",
440+
" int8_model,\n",
441+
" input_shape=[1, 64, 64, 3],\n",
442+
")"
463443
]
464444
},
465445
{
@@ -483,6 +463,10 @@
483463
},
484464
"outputs": [],
485465
"source": [
466+
"serialize(model_ir_fp32, str(fp32_ir_path))\n",
467+
"serialize(model_ir_int8, str(int8_ir_path))\n",
468+
"\n",
469+
"\n",
486470
"def parse_benchmark_output(benchmark_output):\n",
487471
" parsed_output = [line for line in benchmark_output if 'FPS' in line]\n",
488472
" print(*parsed_output, sep='\\n')\n",

0 commit comments

Comments
 (0)