forked from openvinotoolkit/model_api
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_save.py
48 lines (42 loc) · 1.77 KB
/
test_save.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from openvino.model_api.models import Model
def _assert_save_roundtrip(model_name, configuration, tmp_path):
    """Check that a model survives a save/load round trip unchanged.

    Creates the model ``model_name`` with ``configuration`` applied and checks
    that its runtime info has ``model_info/embedded_processing`` set. It then
    saves the model to an XML file under ``tmp_path`` and re-creates a model
    from that file. The reloaded model must be an instance of exactly the same
    wrapper class and must match the original on every attribute named by
    ``parameters()``.

    Args:
        model_name: model identifier (or path) accepted by
            ``Model.create_model``.
        configuration: configuration overrides passed to
            ``Model.create_model``.
        tmp_path: directory for the saved file (the pytest ``tmp_path``
            fixture in the callers below).
    """
    downloaded = Model.create_model(model_name, configuration=configuration)
    # The created model must carry the rt_info flag
    # model_info/embedded_processing; assert its truthiness directly instead
    # of comparing to True.
    assert downloaded.get_model().get_rt_info(
        ["model_info", "embedded_processing"]
    ).astype(bool)

    xml_path = str(tmp_path / "a.xml")
    downloaded.save(xml_path)
    deserialized = Model.create_model(xml_path)

    # Classes are compared by identity, not with ``==``.
    assert type(downloaded) is type(deserialized)
    for attr in downloaded.parameters():
        # Name the offending parameter so a failure is self-explanatory.
        assert getattr(downloaded, attr) == getattr(deserialized, attr), (
            f"parameter {attr!r} differs after save/load"
        )


def test_detector_save(tmp_path):
    """Round-trip an SSD detector with custom mean values and threshold."""
    _assert_save_roundtrip(
        "ssd_mobilenet_v1_fpn_coco",
        {"mean_values": [0, 0, 0], "confidence_threshold": 0.6},
        tmp_path,
    )


def test_classifier_save(tmp_path):
    """Round-trip an EfficientNet classifier with custom scale values and topk."""
    _assert_save_roundtrip(
        "efficientnet-b0-pytorch",
        {"scale_values": [1, 1, 1], "topk": 6},
        tmp_path,
    )


def test_segmentor_save(tmp_path):
    """Round-trip an HRNet segmentor with channel reversal and custom labels."""
    _assert_save_roundtrip(
        "hrnet-v2-c1-segmentation",
        {"reverse_input_channels": True, "labels": ["first", "second"]},
        tmp_path,
    )