Skip to content

Commit 4b7e1b3

Browse files
authored
Add test for node adaptation on Squeeze11 (#155)
1 parent 88a1504 commit 4b7e1b3

File tree

1 file changed

+55
-20
lines changed

1 file changed

+55
-20
lines changed

tests/test_adapt.py

+55-20
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import onnx
66
import onnx.parser
7+
import onnxruntime as ort
78
import pytest
89

910
import spox.opset.ai.onnx.v18 as op18
@@ -71,32 +72,34 @@ def inline_old_identity_twice_graph(old_identity):
7172
return results(final=z).with_opset(("ai.onnx", 17))
7273

7374

74-
@pytest.fixture
75-
def old_squeeze_graph(old_squeeze):
76-
class Squeeze11(StandardNode):
77-
@dataclass
78-
class Attributes(BaseAttributes):
79-
axes: AttrInt64s
75+
class Squeeze11(StandardNode):
76+
@dataclass
77+
class Attributes(BaseAttributes):
78+
axes: AttrInt64s
79+
80+
@dataclass
81+
class Inputs(BaseInputs):
82+
data: Var
83+
84+
@dataclass
85+
class Outputs(BaseOutputs):
86+
squeezed: Var
8087

81-
@dataclass
82-
class Inputs(BaseInputs):
83-
data: Var
88+
op_type = OpType("Squeeze", "", 11)
8489

85-
@dataclass
86-
class Outputs(BaseOutputs):
87-
squeezed: Var
90+
attrs: Attributes
91+
inputs: Inputs
92+
outputs: Outputs
8893

89-
op_type = OpType("Squeeze", "", 11)
9094

91-
attrs: Attributes
92-
inputs: Inputs
93-
outputs: Outputs
95+
def squeeze11(_data: Var, _axes: Iterable[int]):
96+
return Squeeze11(
97+
Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data)
98+
).outputs.squeezed
9499

95-
def squeeze11(_data: Var, _axes: Iterable[int]):
96-
return Squeeze11(
97-
Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data)
98-
).outputs.squeezed
99100

101+
@pytest.fixture
102+
def old_squeeze_graph(old_squeeze):
100103
(data,) = arguments(
101104
data=Tensor(
102105
np.float32,
@@ -233,3 +236,35 @@ def test_inline_model_custom_node_nested(old_squeeze: onnx.ModelProto):
233236
# Add another node to the model to trigger the adaption logic
234237
c = op18.identity(b)
235238
build({"a": a}, {"c": c})
239+
240+
241+
def test_if_adapatation_squeeze():
242+
cond = argument(Tensor(np.bool_, ()))
243+
b = argument(Tensor(np.float32, (1,)))
244+
squeezed = squeeze11(b, [0])
245+
out = op18.if_(
246+
cond,
247+
then_branch=lambda: [squeezed],
248+
else_branch=lambda: [squeeze11(b, [0])],
249+
)
250+
model = build({"b": b, "cond": cond}, {"out": out[0]})
251+
252+
# predict on model
253+
b = np.array([1.1], dtype=np.float32)
254+
cond = np.array(True, dtype=np.bool_)
255+
out = ort.InferenceSession(model.SerializeToString()).run(
256+
None, {"b": b, "cond": cond}
257+
)
258+
259+
260+
def test_if_adaptation_const():
261+
sq = op19.const(1.1453, dtype=np.float32)
262+
b = argument(Tensor(np.float32, ("N",)))
263+
cond = op18.equal(sq, b)
264+
out = op18.if_(cond, then_branch=lambda: [sq], else_branch=lambda: [sq])
265+
model = build({"b": b}, {"out": out[0]})
266+
assert model.domain == "" or model.domain == "ai.onnx"
267+
assert (
268+
model.opset_import[0].domain == "ai.onnx" or model.opset_import[0].domain == ""
269+
)
270+
assert model.opset_import[0].version > 11

0 commit comments

Comments
 (0)