 import numpy as np
 import onnx
 import onnx.parser
+import onnxruntime as ort
 import pytest
 
 import spox.opset.ai.onnx.v18 as op18
@@ -71,32 +72,34 @@ def inline_old_identity_twice_graph(old_identity):
     return results(final=z).with_opset(("ai.onnx", 17))
 
 
-@pytest.fixture
-def old_squeeze_graph(old_squeeze):
-    class Squeeze11(StandardNode):
-        @dataclass
-        class Attributes(BaseAttributes):
-            axes: AttrInt64s
+class Squeeze11(StandardNode):
+    @dataclass
+    class Attributes(BaseAttributes):
+        axes: AttrInt64s
+
+    @dataclass
+    class Inputs(BaseInputs):
+        data: Var
+
+    @dataclass
+    class Outputs(BaseOutputs):
+        squeezed: Var
 
-        @dataclass
-        class Inputs(BaseInputs):
-            data: Var
+    op_type = OpType("Squeeze", "", 11)
 
-        @dataclass
-        class Outputs(BaseOutputs):
-            squeezed: Var
+    attrs: Attributes
+    inputs: Inputs
+    outputs: Outputs
 
-        op_type = OpType("Squeeze", "", 11)
 
-        attrs: Attributes
-        inputs: Inputs
-        outputs: Outputs
+def squeeze11(_data: Var, _axes: Iterable[int]):
+    return Squeeze11(
+        Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data)
+    ).outputs.squeezed
 
-    def squeeze11(_data: Var, _axes: Iterable[int]):
-        return Squeeze11(
-            Squeeze11.Attributes(AttrInt64s(_axes, "axes")), Squeeze11.Inputs(_data)
-        ).outputs.squeezed
 
+@pytest.fixture
+def old_squeeze_graph(old_squeeze):
     (data,) = arguments(
         data=Tensor(
            np.float32,
@@ -233,3 +236,35 @@ def test_inline_model_custom_node_nested(old_squeeze: onnx.ModelProto):
     # Add another node to the model to trigger the adaption logic
     c = op18.identity(b)
     build({"a": a}, {"c": c})
+
+
+def test_if_adapatation_squeeze():
+    cond = argument(Tensor(np.bool_, ()))
+    b = argument(Tensor(np.float32, (1,)))
+    squeezed = squeeze11(b, [0])
+    out = op18.if_(
+        cond,
+        then_branch=lambda: [squeezed],
+        else_branch=lambda: [squeeze11(b, [0])],
+    )
+    model = build({"b": b, "cond": cond}, {"out": out[0]})
+
+    # predict on model
+    b = np.array([1.1], dtype=np.float32)
+    cond = np.array(True, dtype=np.bool_)
+    out = ort.InferenceSession(model.SerializeToString()).run(
+        None, {"b": b, "cond": cond}
+    )
+
+
+def test_if_adaptation_const():
+    sq = op19.const(1.1453, dtype=np.float32)
+    b = argument(Tensor(np.float32, ("N",)))
+    cond = op18.equal(sq, b)
+    out = op18.if_(cond, then_branch=lambda: [sq], else_branch=lambda: [sq])
+    model = build({"b": b}, {"out": out[0]})
+    assert model.domain == "" or model.domain == "ai.onnx"
+    assert (
+        model.opset_import[0].domain == "ai.onnx" or model.opset_import[0].domain == ""
+    )
+    assert model.opset_import[0].version > 11
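
For orientation (not part of the diff): the new tests exercise Spox's version-adaptation pass on nodes nested inside If subgraphs, where an opset-11 Squeeze (whose axes is still an attribute) must be rewritten to the model's target opset (where axes is an input) before the model can import a single version per domain. Below is a minimal sketch of the same idea outside If, assuming the squeeze11 helper from this test module is in scope and that Spox's public argument/build API is used; it is an illustration, not code from the PR.

import numpy as np
import spox.opset.ai.onnx.v18 as op18
from spox import Tensor, argument, build

# Mix the custom opset-11 Squeeze with an opset-18 Identity in one graph.
b = argument(Tensor(np.float32, (1,)))
out = op18.identity(squeeze11(b, [0]))

# build() triggers the adaptation logic; the resulting model should no longer
# declare an opset-11 import for the default ONNX domain.
model = build({"b": b}, {"out": out})
assert all(imp.version != 11 for imp in model.opset_import)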