Skip to content

Commit b816d77

Browse files
authored
Optimize the Workflow of Parsing Keras Model (#1623)
Signed-off-by: zehao-intel <zehao.huang@intel.com>
1 parent f2d9b78 commit b816d77

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

neural_compressor/model/tensorflow_model.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -313,17 +313,10 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_
313313
def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names):
314314
from tensorflow.python.saved_model import signature_constants, tag_constants
315315

316-
from neural_compressor.adaptor.tf_utils.util import parse_saved_model
317-
318316
saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
319317
saved_model_tags = set([tag_constants.SERVING])
320-
try:
321-
graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model(
322-
saved_model_dir, True, input_tensor_names, output_tensor_names
323-
)
324-
except:
325-
return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names)
326-
return graph_def, input_names, output_names
318+
319+
return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names)
327320

328321

329322
def _get_graph_from_original_keras_v2(model, output_dir):
@@ -467,6 +460,15 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
467460
try:
468461
tf.keras.backend.set_learning_phase(0)
469462
graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
463+
except:
464+
keras_format = "saved_model_general"
465+
if keras_format == "saved_model_general": # pargma: no cover
466+
try:
467+
from neural_compressor.adaptor.tf_utils.util import parse_saved_model
468+
469+
graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model(
470+
temp_dir, True, input_tensor_names, output_tensor_names
471+
)
470472
except:
471473
raise ValueError("Not supported keras model type...")
472474

0 commit comments

Comments
 (0)