Skip to content

Commit cfeadc0

Browse files
authored
Serialize model (#37)
* Serialize model * Add tests/python/precommit/test_serialize.py * isort style * Replace [] with '' * Reuse model.parameters() * Reuse tmp_path * serialize->save * black style * ssd_mobilenet_fpn.serialize->ssd_mobilenet_fpn.save
1 parent 126d870 commit cfeadc0

File tree

8 files changed

+80
-21
lines changed

8 files changed

+80
-21
lines changed

.gitignore

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
# mask.png is generated by examples/python/simple_local/run.py
1+
# Generated by examples/python/synchronous_api/run.py
22
mask.png
33
ssd_mobilenet_v1_fpn_coco_with_preprocessing.xml
4+
ssd_mobilenet_v1_fpn_coco_with_preprocessing.bin
5+
tmp/
6+
# Generated by tests/python/accuracy/test_accuracy.py
7+
test_scope.json
48

59
# Byte-compiled / optimized / DLL files
610
__pycache__/

examples/python/synchronous_api/run.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ def main():
5050
)
5151
detections = ssd_mobilenet_fpn(image)
5252
print(f"Detection results: {detections}")
53-
ov.serialize(
54-
ssd_mobilenet_fpn.get_model(),
55-
"ssd_mobilenet_v1_fpn_coco_with_preprocessing.xml",
56-
)
53+
ssd_mobilenet_fpn.save("ssd_mobilenet_v1_fpn_coco_with_preprocessing.xml")
5754

5855
# Instantiate from a local model (downloaded previously)
5956
ssd_mobilenet_fpn_local = DetectionModel.create_model(

model_api/python/openvino/model_api/adapters/openvino_adapter.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
try:
2222
import openvino.runtime as ov
23-
from openvino.preprocess import ColorFormat, PrePostProcessor, ResizeAlgorithm
23+
from openvino.preprocess import ColorFormat, PrePostProcessor
2424
from openvino.runtime import (
2525
AsyncInferQueue,
2626
Core,
@@ -131,7 +131,7 @@ def __init__(
131131
self,
132132
core,
133133
model,
134-
weights_path=None,
134+
weights_path="",
135135
model_parameters={},
136136
device="CPU",
137137
plugin_config=None,
@@ -170,8 +170,7 @@ def __init__(
170170
"from buffer" if self.model_from_buffer else self.model_path
171171
)
172172
)
173-
weights = weights_path if self.model_from_buffer else ""
174-
self.model = core.read_model(self.model_path, weights)
173+
self.model = core.read_model(self.model_path, weights_path)
175174
return
176175
if isinstance(model, str):
177176
from openvino.model_zoo.models import OMZModel, list_models
@@ -409,10 +408,10 @@ def embed_preprocessing(
409408
self.load_model()
410409

411410
def get_model(self):
412-
"""Returns the ov.Model object
411+
"""Returns the openvino.runtime.Model object
413412
414413
Returns:
415-
ov.Model object
414+
openvino.runtime.Model object
416415
"""
417416
return self.model
418417

model_api/python/openvino/model_api/adapters/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy as np
2121
import openvino.runtime as ov
22-
from openvino.runtime import Output, layout_helpers
22+
from openvino.runtime import Output, Type, layout_helpers
2323
from openvino.runtime import opset10 as opset
2424
from openvino.runtime.utils.decorators import custom_preprocess_function
2525

@@ -101,8 +101,8 @@ def resize_image_letterbox_graph(input: Output, size, interpolation="linear"):
101101
opset.gather(image_shape, opset.constant(h_axis), axis=0),
102102
destination_type="f32",
103103
)
104-
w_ratio = opset.divide(opset.constant(w, dtype=float), iw)
105-
h_ratio = opset.divide(opset.constant(h, dtype=float), ih)
104+
w_ratio = opset.divide(opset.constant(w, dtype=Type.f32), iw)
105+
h_ratio = opset.divide(opset.constant(h, dtype=Type.f32), ih)
106106
scale = opset.minimum(w_ratio, h_ratio)
107107
nw = opset.convert(opset.multiply(iw, scale), destination_type="i32")
108108
nh = opset.convert(opset.multiply(ih, scale), destination_type="i32")

model_api/python/openvino/model_api/models/image_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ def parameters(cls):
8989
parameters.update(
9090
{
9191
"mean_values": ListValue(
92-
default_value=None,
92+
default_value=[],
9393
description="Normalization values, which will be subtracted from image channels for image-input layer during preprocessing",
9494
),
9595
"scale_values": ListValue(
96-
default_value=None,
96+
default_value=[],
9797
description="Normalization values, which will divide the image channels for image-input layer",
9898
),
9999
"reverse_input_channels": BooleanValue(

model_api/python/openvino/model_api/models/model.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def create_model(
124124
configuration={},
125125
preload=True,
126126
core=None,
127-
weights_path=None,
127+
weights_path="",
128128
adaptor_parameters={},
129129
device="AUTO",
130130
nstreams="1",
@@ -469,3 +469,20 @@ def log_layers_info(self):
469469
name, metadata.shape, metadata.precision, metadata.layout
470470
)
471471
)
472+
473+
def get_model(self):
474+
model = self.inference_adapter.get_model()
475+
model.set_rt_info(self.__model__, ["model_info", "model_type"])
476+
for name in self.parameters():
477+
if [] == getattr(self, name):
478+
# ov cant serialize empty list. Replace it with ""
479+
# TODO: remove when Anastasia Kuporosova fixes that
480+
model.set_rt_info("", ["model_info", name])
481+
else:
482+
model.set_rt_info(getattr(self, name), ["model_info", name])
483+
return model
484+
485+
def save(self, xml_path, bin_path="", version="UNSPECIFIED"):
486+
import openvino.runtime as ov
487+
488+
ov.serialize(self.get_model(), xml_path, bin_path, version)

model_api/python/openvino/model_api/models/types.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ def __str__(self) -> str:
103103

104104

105105
class StringValue(BaseValue):
106-
def __init__(self, choices=(), **kwargs):
107-
super().__init__(**kwargs)
106+
def __init__(
107+
self, choices=(), description="No description available", default_value=""
108+
):
109+
super().__init__(description, default_value)
108110
self.choices = choices
109111
for choice in self.choices:
110112
if not isinstance(choice, str):
@@ -161,8 +163,10 @@ def validate(self, value):
161163

162164

163165
class ListValue(BaseValue):
164-
def __init__(self, value_type=None, **kwargs) -> None:
165-
super().__init__(**kwargs)
166+
def __init__(
167+
self, value_type=None, description="No description available", default_value=[]
168+
) -> None:
169+
super().__init__(description, default_value)
166170
self.value_type = value_type
167171

168172
def from_str(self, value):

tests/python/precommit/test_save.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from openvino.model_api.models import Model
2+
3+
4+
def test_detector_save(tmp_path):
5+
downloaded = Model.create_model(
6+
"ssd300", configuration={"mean_values": [0, 0, 0], "confidence_threshold": 0.6}
7+
)
8+
xml_path = str(tmp_path / "a.xml")
9+
downloaded.save(xml_path)
10+
deserialized = Model.create_model(xml_path)
11+
assert type(downloaded) == type(deserialized)
12+
for attr in downloaded.parameters():
13+
assert getattr(downloaded, attr) == getattr(deserialized, attr)
14+
15+
16+
def test_classifier_save(tmp_path):
17+
downloaded = Model.create_model(
18+
"efficientnet-b0-pytorch", configuration={"scale_values": [1, 1, 1], "topk": 6}
19+
)
20+
xml_path = str(tmp_path / "a.xml")
21+
downloaded.save(xml_path)
22+
deserialized = Model.create_model(xml_path)
23+
assert type(downloaded) == type(deserialized)
24+
for attr in downloaded.parameters():
25+
assert getattr(downloaded, attr) == getattr(deserialized, attr)
26+
27+
28+
def test_segmentor_save(tmp_path):
29+
downloaded = Model.create_model(
30+
"hrnet-v2-c1-segmentation",
31+
configuration={"reverse_input_channels": True, "labels": ["first", "second"]},
32+
)
33+
xml_path = str(tmp_path / "a.xml")
34+
downloaded.save(xml_path)
35+
deserialized = Model.create_model(xml_path)
36+
assert type(downloaded) == type(deserialized)
37+
for attr in downloaded.parameters():
38+
assert getattr(downloaded, attr) == getattr(deserialized, attr)

0 commit comments

Comments
 (0)