Skip to content

Commit efbe343

Browse files
authored
Add option to drop unused inputs during build (#152)
1 parent 56ce4df commit efbe343

File tree

4 files changed

+54
-4
lines changed

4 files changed

+54
-4
lines changed

CHANGELOG.rst

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
Change log
88
==========
99

10+
0.12.0 (2024-05-16)
11+
-------------------
12+
13+
**New feature**
14+
15+
- The :func:`spox.build` function gained the ``drop_unused_inputs`` argument.
16+
17+
1018
0.11.0 (2024-04-23)
1119
-------------------
1220

pyproject.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,8 @@ check_untyped_defs = true
7272

7373
[tool.pytest.ini_options]
7474
# This will be pytest's future default.
75-
addopts = "--import-mode=importlib"
75+
addopts = "--import-mode=importlib"
76+
filterwarnings = [
77+
# Protobuf warning seen when running the test suite
78+
"ignore:.*Type google.protobuf.pyext.*:DeprecationWarning:.*",
79+
]

src/spox/_public.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def _temporary_renames(**kwargs: Var):
5454
arg._rename(name)
5555

5656

57-
def build(inputs: Dict[str, Var], outputs: Dict[str, Var]) -> onnx.ModelProto:
57+
def build(
58+
inputs: Dict[str, Var], outputs: Dict[str, Var], *, drop_unused_inputs=False
59+
) -> onnx.ModelProto:
5860
"""
5961
Builds an ONNX Model with given model inputs and outputs.
6062
@@ -71,6 +73,11 @@ def build(inputs: Dict[str, Var], outputs: Dict[str, Var]) -> onnx.ModelProto:
7173
Model outputs. Keys are names, values may be any ``Var``.
7274
Building will resolve what nodes were used in the construction
7375
of output variables.
76+
drop_unused_inputs
77+
If ``False``, only inputs that are needed for the computation
78+
of the ``outputs`` will appear as inputs of the returned
79+
``ModelProto``. Otherwise, all inputs are required by the
80+
returned object (default).
7481
7582
Returns
7683
-------
@@ -81,6 +88,11 @@ def build(inputs: Dict[str, Var], outputs: Dict[str, Var]) -> onnx.ModelProto:
8188
the newest one. The minimum ``ai.onnx`` version is set to 14 to
8289
avoid tooling issues with legacy versions.
8390
91+
Raises
92+
------
93+
KeyError
94+
If the ``outputs`` cannot be build from the given ``inputs``.
95+
8496
Examples
8597
--------
8698
>>> import numpy as np
@@ -113,8 +125,15 @@ def build(inputs: Dict[str, Var], outputs: Dict[str, Var]) -> onnx.ModelProto:
113125

114126
with _temporary_renames(**inputs):
115127
graph = results(**outputs)
116-
graph = graph.with_arguments(*inputs.values())
117-
return graph.to_onnx_model()
128+
if not drop_unused_inputs:
129+
graph = graph.with_arguments(*inputs.values())
130+
model_proto = graph.to_onnx_model()
131+
132+
# Validate that no further inputs were required.
133+
if any(inp.name not in inputs for inp in model_proto.graph.input):
134+
raise KeyError("Model requires additional inputs not provided in 'inputs'.")
135+
136+
return model_proto
118137

119138

120139
class _InlineCall(Protocol):

tests/test_public.py

+19
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,22 @@ def test_shallow_deepcopy_var_raises():
100100
a = argument(Tensor(float, ()))
101101
with pytest.raises(ValueError):
102102
deepcopy(a)
103+
104+
105+
def test_build_drop_unused_arguments():
106+
a = argument(Tensor(float, ()))
107+
unused = argument(Tensor(float, ()))
108+
c = op.add(a, a)
109+
model_proto = build({"a": a, "unused": unused}, {"c": c}, drop_unused_inputs=True)
110+
actual_inputs = [el.name for el in model_proto.graph.input]
111+
112+
assert actual_inputs == ["a"]
113+
114+
115+
@pytest.mark.parametrize("drop_unused", [True, False])
116+
def test_raise_missing_input(drop_unused):
117+
a = argument(Tensor(float, ()))
118+
b = argument(Tensor(float, ()))
119+
120+
with pytest.raises(KeyError):
121+
build({"a": a}, {"c": op.add(a, b)}, drop_unused_inputs=drop_unused)

0 commit comments

Comments
 (0)