
Commit 4e45f8f

Improve UT Coverage for TF 3x (#1852)

Signed-off-by: zehao-intel <zehao.huang@intel.com>
Signed-off-by: chensuyue <suyue.chen@intel.com>

1 parent 794b276 commit 4e45f8f


42 files changed: +556 -4011 lines

.azure-pipelines/scripts/ut/3x/coverage.3x_pt (+3)

```diff
@@ -5,6 +5,9 @@ branch = True
 include =
     */neural_compressor/common/*
     */neural_compressor/torch/*
+omit =
+    */neural_compressor/torch/algorithms/habana_fp8/*
+    */neural_compressor/torch/amp/*
 exclude_lines =
     pragma: no cover
     raise NotImplementedError
```
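The two new `omit` globs drop the Habana FP8 and AMP sub-packages from the measured set, so hardware-specific code that CI cannot exercise no longer deflates the 3x PyTorch coverage number. A minimal sketch of how coverage.py applies such an rcfile; the config path and the imported package below are illustrative, and in CI the file is consumed via pytest-cov rather than called like this:

```python
# Minimal sketch, not the pipeline's own invocation: load the rcfile above
# directly with coverage.py to see the effect of the new "omit" entries.
import coverage

# Assumed path relative to the repo root.
cov = coverage.Coverage(config_file=".azure-pipelines/scripts/ut/3x/coverage.3x_pt")
cov.start()
import neural_compressor.torch  # noqa: F401  # exercise some measured code
cov.stop()
cov.report()  # files under habana_fp8/ and amp/ are absent from this report
```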

.azure-pipelines/scripts/ut/3x/run_3x_tf.sh (+19 -3)

```diff
@@ -16,20 +16,36 @@ inc_path=$(python -c 'import neural_compressor; print(neural_compressor.__path__
 cd /neural-compressor/test/3x || exit 1
 rm -rf torch
 rm -rf onnxrt
-rm -rf tensorflow/quantization/ptq/newapi
 mv tensorflow/keras ../3x_keras
-mv tensorflow/quantization/itex ./3x_itex
+mv tensorflow/quantization/ptq/newapi ../3x_newapi
 
 LOG_DIR=/neural-compressor/log_dir
 mkdir -p ${LOG_DIR}
 ut_log_name=${LOG_DIR}/ut_3x_tf.log
+
+# test for tensorflow ut
 pytest --cov="${inc_path}" -vs --disable-warnings --html=report_tf_quant.html --self-contained-html ./tensorflow/quantization 2>&1 | tee -a ${ut_log_name}
 rm -rf tensorflow/quantization
 pytest --cov="${inc_path}" --cov-append -vs --disable-warnings --html=report_tf.html --self-contained-html . 2>&1 | tee -a ${ut_log_name}
 
+# test for tensorflow new api ut
+pip uninstall tensorflow -y
+pip install /tf_dataset/tf_binary/230928/tensorflow*.whl
+pip install cmake
+pip install protobuf==3.20.3
+pip install horovod==0.27.0
+pip list
+rm -rf tensorflow/*
+mkdir -p tensorflow/quantization/ptq
+mv ../3x_newapi tensorflow/quantization/ptq/newapi
+find . -name "test*.py" | sed "s,\.\/,python -m pytest --cov=${inc_path} --cov-append -vs --disable-warnings ,g" > run.sh
+cat run.sh
+bash run.sh 2>&1 | tee -a ${ut_log_name}
+
+# test for itex ut
 rm -rf tensorflow/*
 mv ../3x_keras tensorflow/keras
-mv ../3x_itex tensorflow/quantization/itex
+pip uninstall tensorflow -y
 pip install intel-extension-for-tensorflow[cpu]
 pytest --cov="${inc_path}" --cov-append -vs --disable-warnings --html=report_keras.html --self-contained-html ./tensorflow 2>&1 | tee -a ${ut_log_name}
```
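The consolidated script now runs three phases in one job: the standard TF unit tests, the new-API tests against a pinned internal TensorFlow wheel, and the Keras tests under ITEX, with `--cov-append` accumulating one coverage total across all phases. The find/sed pipeline turns every discovered test file into its own pytest process, so a crash in one new-API test file cannot abort the rest. A rough Python equivalent of what that one-liner writes into run.sh; the `inc_path` value here is illustrative:

```python
# Rough equivalent of the find/sed one-liner above: emit one pytest command
# per test file into run.sh, appending to a shared coverage data file.
from pathlib import Path

inc_path = "/usr/lib/python3/site-packages/neural_compressor"  # illustrative
commands = [
    f"python -m pytest --cov={inc_path} --cov-append -vs --disable-warnings {p}"
    for p in sorted(Path(".").rglob("test*.py"))
]
Path("run.sh").write_text("\n".join(commands) + "\n")
```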
.azure-pipelines/scripts/ut/3x/run_3x_tf_new_api.sh (-46)

This file was deleted.

.azure-pipelines/ut-3x-tf.yml (-14)

```diff
@@ -41,20 +41,6 @@ stages:
           uploadPath: $(UPLOAD_PATH)
           utArtifact: "ut_3x"
 
-  - stage: NewTF
-    displayName: Unit Test 3x New TF API
-    dependsOn: []
-    jobs:
-      - job:
-          displayName: Unit Test 3x New TF API
-          steps:
-            - template: template/ut-template.yml
-              parameters:
-                dockerConfigName: "commonDockerConfig"
-                utScriptFileName: "3x/run_3x_tf_new_api"
-                uploadPath: $(UPLOAD_PATH)
-                utArtifact: "ut_3x_tf_new_api"
-
   - stage: TensorFlow_baseline
     displayName: Unit Test 3x TensorFlow baseline
     dependsOn: []
```

.github/checkgroup.yml (-4)

```diff
@@ -53,11 +53,7 @@ subprojects:
       - "Model-Test (Run ONNX Model resnet50-v1-12)"
       - "Model-Test (Run PyTorch Model resnet18)"
       - "Model-Test (Run PyTorch Model resnet18_fx)"
-      - "Model-Test (Run TensorFlow Model darknet19)"
-      - "Model-Test (Run TensorFlow Model inception_v1)"
-      - "Model-Test (Run TensorFlow Model resnet-101)"
       - "Model-Test (Run TensorFlow Model resnet50v1.5)"
-      - "Model-Test (Run TensorFlow Model ssd_mobilenet_v1_ckpt)"
       - "Model-Test (Run TensorFlow Model ssd_resnet50_v1)"
 
   - id: "Model Tests 3x workflow"
```

neural_compressor/tensorflow/algorithms/static_quant/keras.py (+6 -125)

```diff
@@ -90,46 +90,13 @@ def __init__(self, framework_specific_info):
             os.mkdir(DEFAULT_WORKSPACE)
         self.tmp_dir = (DEFAULT_WORKSPACE + "tmp_model.keras") if self.keras3 else (DEFAULT_WORKSPACE + "tmp_model")
 
-    def _check_itex(self):
-        """Check if the Intel® Extension for TensorFlow has been installed."""
-        try:
-            import intel_extension_for_tensorflow
-        except:
-            raise ImportError(
-                "The Intel® Extension for TensorFlow is not installed. "
-                "Please install it to run models on ITEX backend"
-            )
-
-    def convert_bf16(self):
-        """Execute the BF16 conversion."""
-        tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
-        model = self.pre_optimized_model
-
-        for layer in model.layers:
-            if layer.name in self.bf16_ops:
-                layer.dtype = "mixed_bfloat16"
-
-        model.save(self.tmp_dir)
-        converted_model = tf.keras.models.load_model(self.tmp_dir)
-        tf.keras.mixed_precision.set_global_policy("float32")
-
-        return converted_model
-
-    # (TODO) choose the properly quantize mode
-    def _check_quantize_mode(self, model):
-        """Check what quantize mode to use."""
-        for layer in model.layers:
-            if "ReLU" in layer.__class__.__name__:
-                return "MIN_FIRST"
-        return "SCALED"
-
     def _set_weights(self, qmodel, layer_weights):
         """Set fp32 weights to qmodel."""
         for qlayer in qmodel.layers:
             if qlayer.get_weights():
                 if qlayer.name in layer_weights:
                     qlayer.set_weights(layer_weights[qlayer.name])
-                else:
+                else:  # pragma: no cover
                     hit_layer = False
                     for sub_layer in qlayer.submodules:
                         if sub_layer.name in layer_weights:
@@ -164,7 +131,7 @@ def _check_quantize_format(self, model):
                     self.conv_format[layer.name] = "u8"
                     break
 
-    def _fuse_bn_keras3(self, fuse_conv_bn, fp32_layers):
+    def _fuse_bn_keras3(self, fuse_conv_bn, fp32_layers):  # pragma: no cover
         fuse_layers = []
         fused_bn_name = ""
         for idx, layer in enumerate(fp32_layers):
@@ -211,7 +178,7 @@ def _fuse_bn_keras3(self, fuse_conv_bn, fp32_layers):
 
         return fuse_layers
 
-    def _fuse_bn_keras2(self, fuse_conv_bn, fp32_layers):
+    def _fuse_bn_keras2(self, fuse_conv_bn, fp32_layers):  # pragma: no cover
         fuse_layers = []
         for idx, layer in enumerate(fp32_layers):
             if hasattr(layer, "_inbound_nodes"):
@@ -272,7 +239,7 @@ def _fuse_bn_keras2(self, fuse_conv_bn, fp32_layers):
 
         return fuse_layers
 
-    def _fuse_bn(self, model):
+    def _fuse_bn(self, model):  # pragma: no cover
         """Fusing Batch Normalization."""
         model.save(self.tmp_dir)
         fuse_bn_model = tf.keras.models.load_model(self.tmp_dir)
@@ -362,14 +329,6 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None):
         tune_cfg = converter.parse_to_tune_cfg()
         self.tuning_cfg_to_fw(tune_cfg)
 
-        # just convert the input model to mixed_bfloat16
-        if self.bf16_ops and not self.quantize_config["op_wise_config"]:
-            converted_model = self.convert_bf16()
-            return converted_model
-
-        # if self.backend == "itex":
-        #     self._check_itex()
-
         logger.debug("Dump quantization configurations:")
         logger.debug(self.quantize_config)
         calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
@@ -469,59 +428,6 @@ def _calibrate(self, model, dataloader, calib_interation):
 
         return quantized_model
 
-    @dump_elapsed_time(customized_msg="Model inference")
-    def evaluate(
-        self,
-        model,
-        dataloader,
-        postprocess=None,
-        metrics=None,
-        measurer=None,
-        iteration=-1,
-        fp32_baseline=False,
-    ):
-        """The function is used to run evaluation on validation dataset.
-
-        Args:
-            model (object): The model to do calibration.
-            dataloader (generator): generate the data and labels.
-            postprocess (object, optional): process the result from the model
-            metric (object, optional): Depends on model category. Defaults to None.
-            measurer (object, optional): for precise benchmark measurement.
-            iteration(int, optional): control steps of mini-batch
-            fp32_baseline (boolean, optional): only for compare_label=False pipeline
-        """
-        # use keras object
-        keras_model = model.model
-        logger.info("Start to evaluate the Keras model.")
-        results = []
-        for idx, (inputs, labels) in enumerate(dataloader):
-            # use predict on batch
-            if measurer is not None:
-                measurer.start()
-                predictions = keras_model.predict_on_batch(inputs)
-                measurer.end()
-            else:
-                predictions = keras_model.predict_on_batch(inputs)
-
-            if self.fp32_preds_as_label:
-                self.fp32_results.append(predictions) if fp32_baseline else results.append(predictions)
-
-            if postprocess is not None:
-                predictions, labels = postprocess((predictions, labels))
-            if metrics:
-                for metric in metrics:
-                    if not hasattr(metric, "compare_label") or (
-                        hasattr(metric, "compare_label") and metric.compare_label
-                    ):
-                        metric.update(predictions, labels)
-            if idx + 1 == iteration:
-                break
-
-        acc = 0 if metrics is None else [metric.result() for metric in metrics]
-
-        return acc if not isinstance(acc, list) or len(acc) > 1 else acc[0]
-
     def query_fw_capability(self, model):
         """The function is used to return framework tuning capability.
 
@@ -621,7 +527,7 @@ def tuning_cfg_to_fw(self, tuning_cfg):
         for each_op_info in tuning_cfg["op"]:
             op_name = each_op_info[0]
 
-            if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16":
+            if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16":  # pragma: no cover
                 if each_op_info[1] in bf16_type:
                     bf16_ops.append(op_name)
                     continue
@@ -693,31 +599,6 @@ def _get_specified_version_cfg(self, data):
 
         return default_config
 
-    def get_version(self):
-        """Get the current backend version information.
-
-        Returns:
-            [string]: version string.
-        """
-        return self.cur_config["version"]["name"]
-
-    def get_precisions(self):
-        """Get supported precisions for current backend.
-
-        Returns:
-            [string list]: the precisions' name.
-        """
-        return self.cur_config["precisions"]["names"]
-
-    def get_op_types(self):
-        """Get the supported op types by all precisions.
-
-        Returns:
-            [dictionary list]: A list composed of dictionary which key is precision
-            and value is the op types.
-        """
-        return self.cur_config["ops"]
-
     def get_quantization_capability(self):
         """Get the supported op types' quantization capability.
 
@@ -846,7 +727,7 @@ def _parse_inputs(self, BN_fused_layers=None, conv_names=None):
 
         try:
             model_input = self.model.input
-        except ValueError:
+        except ValueError:  # pragma: no cover
            model_input = self.model.inputs[0]
 
        return input_layer_dict, model_input
```
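Most of the +6 lines in this file are `# pragma: no cover` markers, which pair with the `exclude_lines` entry in the coverage config above: coverage.py excludes the marked line, and the block it opens, from the statistics, so fallback branches the unit tests cannot reach stop counting against the coverage target. A minimal illustration; the function below is hypothetical, not part of this file:

```python
# Hypothetical example of the "# pragma: no cover" marker used in this diff.
# coverage.py (via the exclude_lines setting) omits the marked branch from
# its statistics, so the rarely-taken fallback does not lower the percentage.
def first_input(model):
    try:
        return model.input
    except ValueError:  # pragma: no cover
        return model.inputs[0]
```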
