diff --git a/docs/source/quick_start/local_tutorial.md b/docs/source/quick_start/local_tutorial.md index b3a75fb00..adb4bb972 100644 --- a/docs/source/quick_start/local_tutorial.md +++ b/docs/source/quick_start/local_tutorial.md @@ -24,7 +24,7 @@ pip install tensorflow_probability==0.5.0 常见版本对应关系: | TensorFlow版本 | TensorFlowProbability版本 | -|--------------|-------------------------| +| ------------ | ----------------------- | | 1.12 | 0.5.0 | | 1.15 | 0.8.0 | | 2.5.0 | 0.13.0 | diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py index ded5f0bf4..d86791708 100644 --- a/easy_rec/python/tools/split_model_pai.py +++ b/easy_rec/python/tools/split_model_pai.py @@ -9,6 +9,7 @@ from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework.dtypes import _TYPE_TO_STRING +from tensorflow.python.ops.resource_variable_ops import _from_proto_fn from tensorflow.python.saved_model import signature_constants from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as tf_saver @@ -18,7 +19,6 @@ if tf.__version__ >= '2.0': tf = tf.compat.v1 from tensorflow.python.saved_model.path_helpers import get_variables_path - from tensorflow.python.ops.resource_variable_ops import _from_proto_fn else: from tensorflow.python.saved_model.utils_impl import get_variables_path @@ -207,10 +207,7 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names, graph = ops.get_default_graph() importer.import_graph_def(inference_graph, name='') for name in variables_to_keep: - if tf.__version__ >= '2.0': - variable = _from_proto_fn(variable_protos[name.split(':')[0]]) - else: - variable = graph.get_tensor_by_name(name) + variable = _from_proto_fn(variable_protos[name.split(':')[0]]) graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable) saver = tf_saver.Saver() saver.restore(sess, get_variables_path(model_dir))