Skip to content

Commit 42e0aa3

Browse files
Do not prefix unset inputs when inlining (#160)
Co-authored-by: Christian Bourjau <christian.bourjau@quantco.com>
1 parent 65e91aa commit 42e0aa3

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

CHANGELOG.rst

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77
Change log
88
==========
99

10-
0.12.1 (unreleased)
10+
0.12.1 (2024-06-18)
1111
-------------------
1212

13+
**Bug fix**
14+
15+
- Unset optional inputs are no longer erroneously prefixed by :func:`~spox.inline`.
16+
17+
1318
**Other changes**
1419

1520
- The node-naming algorithm now has constant rather than quadratic time complexity.

src/spox/_inline.py

+2
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ def to_onnx(
154154
inner_node_renames: Dict[str, str] = {}
155155

156156
def reserve_prefixed(name: str) -> str:
157+
if not name:
158+
return name
157159
return scope.var.reserve(
158160
scope.var.maybe_enum(f"{scope.node[self]}__{name}")
159161
)

tests/test_inline.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66
import pytest
77

88
import spox.opset.ai.onnx.v17 as op
9+
from spox import Tensor, Var, argument, build, inline
910
from spox._graph import arguments, results
1011
from spox._inline import rename_in_graph
11-
from spox._public import inline
12-
from spox._type_system import Tensor
1312
from spox._utils import from_array
14-
from spox._var import Var
1513

1614

1715
@pytest.fixture
@@ -342,3 +340,21 @@ def example_rename(n: str) -> str:
342340
_duplicate_subgraphs_to_list(relu_proto.graph), example_rename
343341
)
344342
assert rename_then_duplicate.node == duplicate_then_rename.node
343+
344+
345+
def test_subgraph_with_nodes_with_optional_inputs():
346+
"""Unset optional inputs must not be prefixed by `inline`."""
347+
348+
def inline_model() -> onnx.ModelProto:
349+
a = argument(Tensor(numpy.float64, ("N",)))
350+
return build({"a": a}, {"b": op.clip(a, None, op.const(1.0, numpy.float64))})
351+
352+
foo = argument(Tensor(numpy.float64, ("N",)))
353+
(bar,) = inline(inline_model())(foo).values()
354+
355+
model_proto = build({"foo": foo}, {"bar": bar})
356+
357+
(clip_node,) = (n for n in model_proto.graph.node if n.op_type == "Clip")
358+
assert len(clip_node.input) == 3
359+
assert clip_node.input[1] == ""
360+
assert clip_node.input[2] != ""

0 commit comments

Comments
 (0)