Skip to content

Commit 17d4285

Browse files
added inputs check
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent b284bc9 commit 17d4285

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

optimum/intel/openvino/modeling.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -994,11 +994,18 @@ class OVModelForCustomTasks(OVModel):
994994
)
995995
)
996996
def forward(self, **kwargs):
997-
np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray)
997+
expected_inputs_names = set(self.input_names)
998+
inputs_names = set(kwargs)
998999

1000+
if not expected_inputs_names.issubset(inputs_names):
1001+
raise ValueError(
1002+
f"Got unexpected inputs: expecting the following inputs : {', '.join(expected_inputs_names)} but got : {', '.join(inputs_names)}."
1003+
)
1004+
1005+
np_inputs = isinstance(next(iter(kwargs.values())), np.ndarray)
9991006
inputs = {}
1000-
for key, value in kwargs.items():
1001-
inputs[key] = np.array(value) if not np_inputs else value
1007+
for input_name in self.input_names:
1008+
inputs[input_name] = np.array(kwargs.pop(input_name)) if not np_inputs else kwargs.pop(input_name)
10021009

10031010
outputs = self.request(inputs)
10041011

0 commit comments

Comments
 (0)