File tree 1 file changed +10
-3
lines changed
1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -994,11 +994,18 @@ class OVModelForCustomTasks(OVModel):
994
994
)
995
995
)
996
996
def forward (self , ** kwargs ):
997
- np_inputs = isinstance (next (iter (kwargs .values ())), np .ndarray )
997
+ expected_inputs_names = set (self .input_names )
998
+ inputs_names = set (kwargs )
998
999
1000
+ if not expected_inputs_names .issubset (inputs_names ):
1001
+ raise ValueError (
1002
+ f"Got unexpected inputs: expecting the following inputs : { ', ' .join (expected_inputs_names )} but got : { ', ' .join (inputs_names )} ."
1003
+ )
1004
+
1005
+ np_inputs = isinstance (next (iter (kwargs .values ())), np .ndarray )
999
1006
inputs = {}
1000
- for key , value in kwargs . items () :
1001
- inputs [key ] = np .array (value ) if not np_inputs else value
1007
+ for input_name in self . input_names :
1008
+ inputs [input_name ] = np .array (kwargs . pop ( input_name )) if not np_inputs else kwargs . pop ( input_name )
1002
1009
1003
1010
outputs = self .request (inputs )
1004
1011
You can’t perform that action at this time.
0 commit comments