6
6
7
7
8
8
import logging as log
9
+ import numpy as np
9
10
import sys
11
+ from openvino .runtime import PartialShape , Dimension , Type
10
12
from packaging .version import parse , Version
11
13
from typing import List , Dict , Union
12
14
13
- import numpy as np
14
- from openvino .runtime import PartialShape , Dimension , Type
15
-
16
15
17
16
# TODO: reuse this method in ovc and remove duplication
18
17
def get_static_shape (shape : [PartialShape , list , tuple ], dynamic_value = None ):
@@ -106,13 +105,32 @@ def trace_tf_model_if_needed(input_model, placeholder_shapes, placeholder_data_t
106
105
return trace_tf_model (input_model , placeholder_shapes , placeholder_data_types , example_input )
107
106
108
107
109
- def get_input_spec_from_model (model ):
108
+ def partial_shape_to_list (partial_shape : PartialShape ):
109
+ if partial_shape .rank .is_dynamic :
110
+ return None
111
+ res_list = []
112
+ for dim in partial_shape :
113
+ if dim .is_static :
114
+ res_list .append (dim .get_length ())
115
+ else :
116
+ res_list .append (None )
117
+ return res_list
118
+
119
+
120
+ def get_input_spec_from_model (model , input_shapes = None ):
110
121
import tensorflow as tf
111
122
if hasattr (model , "_build_input_shape" ) and model ._build_input_shape is not None :
112
123
if isinstance (model ._build_input_shape , list ):
113
124
input_spec = [[tf .TensorSpec (shape ) for shape in model ._build_input_shape ]]
114
125
else :
115
126
input_spec = [tf .TensorSpec (model ._build_input_shape )]
127
+ elif input_shapes and isinstance (input_shapes , list ) and len (input_shapes ) > 0 :
128
+ input_spec = []
129
+ for input_shape in input_shapes :
130
+ if isinstance (input_shape , PartialShape ):
131
+ input_spec .append (tf .TensorSpec (partial_shape_to_list (input_shape )))
132
+ else :
133
+ input_spec .append (tf .TensorSpec (None ))
116
134
else :
117
135
input_spec = [tf .TensorSpec (None )]
118
136
return input_spec
@@ -199,10 +217,13 @@ def create_generic_function_from_keras_model(keras_model):
199
217
if tf_input_signature is not None :
200
218
@tf .function (input_signature = tf_input_signature )
201
219
def wrapper_function_dict (* args ):
202
- input_dict = {}
203
- for ind , tensor_spec in enumerate (tf_input_signature ):
204
- input_dict [tensor_spec .name ] = args [ind ]
205
- outputs = keras_model (input_dict )
220
+ if isinstance (keras_input_signature , list ):
221
+ outputs = keras_model (args )
222
+ else :
223
+ input_dict = {}
224
+ for ind , tensor_spec in enumerate (tf_input_signature ):
225
+ input_dict [tensor_spec .name ] = args [ind ]
226
+ outputs = keras_model (input_dict )
206
227
# need to wrap the output into dictionary
207
228
# it helps to preserve original keras tensor names
208
229
post_outputs = {}
@@ -276,7 +297,7 @@ def are_shapes_defined(shape: Union[List, Dict]):
276
297
"Could not trace the TF model with the following error: {}" ,
277
298
use_example_input = False )
278
299
else :
279
- input_spec = get_input_spec_from_model (model )
300
+ input_spec = get_input_spec_from_model (model , input_shapes )
280
301
concrete_func = get_concrete_func (tf_function , input_spec , input_needs_packing ,
281
302
"Could not trace the TF model with the following error: {}.\n "
282
303
"Please provide 'example_input'." )
@@ -457,4 +478,4 @@ def tf_type_to_ov_type(val):
457
478
}
458
479
if val not in tf_to_ov_type :
459
480
raise Exception ("The provided data type is not supported by OpenVino {}." .format (val ))
460
- return tf_to_ov_type [val ]
481
+ return tf_to_ov_type [val ]
0 commit comments