Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] model_split use _from_proto_fn instead of get_tensor_by_name for TF1 #519

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/quick_start/local_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip install tensorflow_probability==0.5.0
常见版本对应关系:

| TensorFlow版本 | TensorFlowProbability版本 |
|--------------|-------------------------|
| ------------ | ----------------------- |
| 1.12 | 0.5.0 |
| 1.15 | 0.8.0 |
| 2.5.0 | 0.13.0 |
Expand Down
7 changes: 2 additions & 5 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
Expand All @@ -18,7 +19,6 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
from tensorflow.python.saved_model.path_helpers import get_variables_path
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
else:
from tensorflow.python.saved_model.utils_impl import get_variables_path

Expand Down Expand Up @@ -207,10 +207,7 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
graph = ops.get_default_graph()
importer.import_graph_def(inference_graph, name='')
for name in variables_to_keep:
if tf.__version__ >= '2.0':
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
else:
variable = graph.get_tensor_by_name(name)
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
saver = tf_saver.Saver()
saver.restore(sess, get_variables_path(model_dir))
Expand Down