Skip to content

Commit

Permalink
use _from_proto_fn instead of get_tensor_by_name (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng authored Jan 22, 2025
1 parent f175907 commit bc38227
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/source/quick_start/local_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip install tensorflow_probability==0.5.0
常见版本对应关系:

| TensorFlow版本 | TensorFlowProbability版本 |
|--------------|-------------------------|
| ------------ | ----------------------- |
| 1.12 | 0.5.0 |
| 1.15 | 0.8.0 |
| 2.5.0 | 0.13.0 |
Expand Down
7 changes: 2 additions & 5 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework.dtypes import _TYPE_TO_STRING
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
Expand All @@ -18,7 +19,6 @@
if tf.__version__ >= '2.0':
tf = tf.compat.v1
from tensorflow.python.saved_model.path_helpers import get_variables_path
from tensorflow.python.ops.resource_variable_ops import _from_proto_fn
else:
from tensorflow.python.saved_model.utils_impl import get_variables_path

Expand Down Expand Up @@ -207,10 +207,7 @@ def export(model_dir, meta_graph_def, variable_protos, input_tensor_names,
graph = ops.get_default_graph()
importer.import_graph_def(inference_graph, name='')
for name in variables_to_keep:
if tf.__version__ >= '2.0':
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
else:
variable = graph.get_tensor_by_name(name)
variable = _from_proto_fn(variable_protos[name.split(':')[0]])
graph.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, variable)
saver = tf_saver.Saver()
saver.restore(sess, get_variables_path(model_dir))
Expand Down

0 comments on commit bc38227

Please sign in to comment.