Skip to content

Commit 2e68fd3

Browse files
authored
[PyOV] Fix passing Model as argument (openvinotoolkit#28896)
### Details: - `run_on_model` was broken due-to changes introduced by this PR openvinotoolkit#27191 - add more tests to cover such cases ### Tickets: - CVS-162131
1 parent fa3a158 commit 2e68fd3

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

src/bindings/python/src/pyopenvino/graph/passes/model_pass.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <string>
1111

1212
#include "pyopenvino/core/common.hpp"
13+
#include "pyopenvino/utils/utils.hpp"
1314

1415
namespace py = pybind11;
1516

@@ -34,10 +35,14 @@ void regclass_passes_ModelPass(py::module m) {
3435
"ModelPass");
3536
model_pass.doc() = "openvino.passes.ModelPass wraps ov::pass::ModelPass";
3637
model_pass.def(py::init<>());
37-
model_pass.def("run_on_model",
38-
&ov::pass::ModelPass::run_on_model,
39-
py::arg("model"),
40-
R"(
38+
model_pass.def(
39+
"run_on_model",
40+
[](ov::pass::ModelPass& self, const py::object& ie_api_model) {
41+
const auto model = Common::utils::convert_to_model(ie_api_model);
42+
self.run_on_model(model);
43+
},
44+
py::arg("model"),
45+
R"(
4146
run_on_model must be defined in inherited class. This method is used to work with Model directly.
4247
4348
:param model: openvino.Model to be transformed.

src/bindings/python/tests/test_transformations/test_model_pass.py

+7
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,10 @@ def test_model_pass():
1212
manager.run_passes(get_relu_model())
1313

1414
assert model_pass.model_changed
15+
16+
17+
def test_model_pass_run_on_model():
18+
model_pass = MyModelPass()
19+
model_pass.run_on_model(get_relu_model())
20+
21+
assert model_pass.model_changed

src/bindings/python/tests/test_transformations/test_public_transformations.py

+11
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,14 @@ def test_serialize_pass(request, tmp_path, is_path_xml, is_path_bin):
153153

154154
os.remove(xml_path)
155155
os.remove(bin_path)
156+
157+
158+
@pytest.mark.parametrize(("transformation", "arguments"),
159+
[(ConstantFolding, []),
160+
(MakeStateful, [{"parameter": "result"}]),
161+
(ConvertFP32ToFP16, []),
162+
(LowLatency2, [])])
163+
def test_run_on_model_transformations(transformation, arguments):
164+
model = get_model()
165+
transformation(*arguments).run_on_model(model)
166+
assert model is not None

0 commit comments

Comments
 (0)