Skip to content

Commit 2886686

Browse files
Caching for attributes that allow it (#123)
* Did I just cache? * Bind name to attributes upon construction (#126) * At least tests pass, and that's something * Remove assert * Contain caching to attributes class * No need to check for AttrGraph and AttrType apparently - they are cacheable? * Pass name in a failing test * appease pch * Update CHANGELOG.rst
1 parent b7b9ca9 commit 2886686

18 files changed

+740
-568
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Change log
1717
**Other changes**
1818

1919
- The validation of Node attributes has been improved and more consistent exceptions are raised if needed.
20+
- ONNX node attributes are now computed only once and then cached so that the values are reused for validation and building the model.
2021

2122

2223
0.9.3 (2023-10-23)

src/spox/_attributes.py

+54-35
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,25 @@
2323

2424
class Attr(ABC, Generic[T]):
2525
_value: Union[T, "_Ref[T]"]
26+
_name: str
27+
_cached_onnx: Optional[AttributeProto]
2628

27-
def __init__(self, value: Union[T, "_Ref[T]"]):
29+
def __init__(self, value: Union[T, "_Ref[T]"], name: str):
2830
self._value = value
31+
self._name = name
32+
self._cached_onnx = None
33+
2934
self._validate()
3035

36+
def deref(self) -> "Attr":
37+
if isinstance(self._value, _Ref):
38+
return type(self)(self.value, self._name)
39+
else:
40+
return self
41+
3142
@classmethod
32-
def maybe(cls: Type[AttrT], value: Optional[T]) -> Optional[AttrT]:
33-
return cls(value) if value is not None else None
43+
def maybe(cls: Type[AttrT], value: Optional[T], name: str) -> Optional[AttrT]:
44+
return cls(value, name) if value is not None else None
3445

3546
@property
3647
def value(self) -> T:
@@ -41,7 +52,7 @@ def value(self) -> T:
4152

4253
def _validate(self):
4354
try:
44-
type_in_onnx = self._to_onnx_deref("dummy").type
55+
type_in_onnx = self._to_onnx().type
4556
except Exception as e:
4657
# Likely an error from within onnx/protobuf, such as:
4758
# 1) AttributeError: 'int' object has no attribute 'encode'
@@ -52,18 +63,19 @@ def _validate(self):
5263
if type_in_onnx != self._attribute_proto_type:
5364
raise self._get_pretty_type_exception()
5465

55-
def _to_onnx(self, key: str) -> AttributeProto:
66+
def _to_onnx(self) -> AttributeProto:
5667
if isinstance(self._value, _Ref):
57-
return self._value._to_onnx(key)
58-
return self._to_onnx_deref(key)
68+
return self._value._to_onnx()
69+
if self._cached_onnx is None:
70+
self._cached_onnx = self._to_onnx_deref()
71+
return self._cached_onnx
5972

6073
@property
6174
@abc.abstractmethod
6275
def _attribute_proto_type(self) -> int:
6376
raise NotImplementedError()
6477

65-
@abc.abstractmethod
66-
def _to_onnx_deref(self, key: str) -> AttributeProto:
78+
def _to_onnx_deref(self) -> AttributeProto:
6779
"""Conversion method for the dereferenced case."""
6880
raise NotImplementedError()
6981

@@ -87,57 +99,58 @@ class _Ref(Generic[T]):
8799

88100
_concrete: Attr[T]
89101

90-
def __init__(self, concrete: Attr[T], outer_name: str):
102+
def __init__(self, concrete: Attr[T], outer_name: str, name: str):
91103
self._concrete = concrete
92104
self._outer_name = outer_name
105+
self._name = name
93106

94107
def copy(self) -> "_Ref[T]":
95108
return self
96109

97-
def _to_onnx(self, key: str) -> AttributeProto:
98-
parent_type = self._concrete._to_onnx(key).type
110+
def _to_onnx(self) -> AttributeProto:
111+
parent_type = self._concrete._to_onnx().type
99112
return AttributeProto(
100-
name=key, ref_attr_name=self._outer_name, type=parent_type
113+
name=self._name, ref_attr_name=self._outer_name, type=parent_type
101114
)
102115

103116

104117
class AttrFloat32(Attr[float]):
105118
_attribute_proto_type = AttributeProto.FLOAT
106119

107-
def _to_onnx_deref(self, key: str) -> AttributeProto:
120+
def _to_onnx_deref(self) -> AttributeProto:
108121
if isinstance(self.value, int):
109-
return make_attribute(key, float(self.value))
110-
return make_attribute(key, self.value)
122+
return make_attribute(self._name, float(self.value))
123+
return make_attribute(self._name, self.value)
111124

112125

113126
class AttrInt64(Attr[int]):
114127
_attribute_proto_type = AttributeProto.INT
115128

116-
def _to_onnx_deref(self, key: str) -> AttributeProto:
117-
return make_attribute(key, self.value)
129+
def _to_onnx_deref(self) -> AttributeProto:
130+
return make_attribute(self._name, self.value)
118131

119132

120133
class AttrString(Attr[str]):
121134
_attribute_proto_type = AttributeProto.STRING
122135

123-
def _to_onnx_deref(self, key: str) -> AttributeProto:
124-
return make_attribute(key, self.value)
136+
def _to_onnx_deref(self) -> AttributeProto:
137+
return make_attribute(self._name, self.value)
125138

126139

127140
class AttrTensor(Attr[np.ndarray]):
128141
_attribute_proto_type = AttributeProto.TENSOR
129142

130-
def __init__(self, value: Union[np.ndarray, _Ref[np.ndarray]]):
131-
super().__init__(value.copy())
143+
def __init__(self, value: Union[np.ndarray, _Ref[np.ndarray]], name: str):
144+
super().__init__(value.copy(), name)
132145

133-
def _to_onnx_deref(self, key: str) -> AttributeProto:
134-
return make_attribute(key, from_array(self.value))
146+
def _to_onnx_deref(self) -> AttributeProto:
147+
return make_attribute(self._name, from_array(self.value))
135148

136149

137150
class AttrType(Attr[_type_system.Type]):
138151
_attribute_proto_type = AttributeProto.TYPE_PROTO
139152

140-
def _to_onnx_deref(self, key: str) -> AttributeProto:
153+
def _to_onnx_deref(self) -> AttributeProto:
141154
value = self.value # for type-checkers with limited property support
142155
if isinstance(value, _type_system.Tensor):
143156
type_proto = make_tensor_type_proto(
@@ -150,7 +163,7 @@ def _to_onnx_deref(self, key: str) -> AttributeProto:
150163
type_proto = make_optional_type_proto(value.elem_type._to_onnx())
151164
else:
152165
raise NotImplementedError()
153-
return make_attribute(key, type_proto)
166+
return make_attribute(self._name, type_proto)
154167

155168

156169
class AttrDtype(Attr[npt.DTypeLike]):
@@ -161,8 +174,8 @@ class AttrDtype(Attr[npt.DTypeLike]):
161174
def _validate(self):
162175
dtype_to_tensor_type(self.value)
163176

164-
def _to_onnx_deref(self, key: str) -> AttributeProto:
165-
return make_attribute(key, dtype_to_tensor_type(self.value))
177+
def _to_onnx_deref(self) -> AttributeProto:
178+
return make_attribute(self._name, dtype_to_tensor_type(self.value))
166179

167180

168181
class AttrGraph(Attr[Any]):
@@ -176,24 +189,30 @@ def _validate(self):
176189
f"Expected value of type `spox.graph.Graph found `{type(self.value)}`"
177190
)
178191

179-
def _to_onnx_deref(self, key: str) -> AttributeProto:
192+
def _to_onnx_deref(self) -> AttributeProto:
180193
raise TypeError(
181194
"Graph attributes must be built using the `build_subgraph` callback in `Node.to_onnx`."
182195
)
183196

184197

185198
class _AttrIterable(Attr[Tuple[S, ...]], ABC):
186-
def __init__(self, value: Union[Iterable[S], _Ref[Tuple[S, ...]]]):
187-
super().__init__(value if isinstance(value, _Ref) else tuple(value))
199+
def __init__(self, value: Union[Iterable[S], _Ref[Tuple[S, ...]]], name: str):
200+
super().__init__(
201+
value=value if isinstance(value, _Ref) else tuple(value), name=name
202+
)
188203

189204
@classmethod
190205
def maybe(
191-
cls: Type[AttrIterableT], value: Optional[Iterable[S]]
206+
cls: Type[AttrIterableT],
207+
value: Optional[Iterable[S]],
208+
name: str,
192209
) -> Optional[AttrIterableT]:
193-
return cls(tuple(value)) if value is not None else None
210+
return cls(tuple(value), name) if value is not None else None
194211

195-
def _to_onnx_deref(self, key: str) -> AttributeProto:
196-
return make_attribute(key, self.value, attr_type=self._attribute_proto_type)
212+
def _to_onnx_deref(self) -> AttributeProto:
213+
return make_attribute(
214+
self._name, self.value, attr_type=self._attribute_proto_type
215+
)
197216

198217

199218
class AttrFloat32s(_AttrIterable[float]):

src/spox/_function.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Function(_InternalNode):
3939
"""
4040

4141
func_args: Dict[str, Var]
42-
func_attrs: Dict[str, _attributes._Ref]
42+
func_attrs: Dict[str, _attributes.Attr]
4343
func_inputs: BaseInputs
4444
func_outputs: BaseOutputs
4545
func_graph: "_graph.Graph"
@@ -64,14 +64,13 @@ def infer_output_types(self) -> Dict[str, Type]:
6464
**{name: var.type for name, var in self.inputs.get_vars().items()}
6565
)
6666

67-
func_attrs = {}
67+
self.func_attrs = {}
6868
for name, attr in self.attrs.get_fields().items():
6969
if attr is None:
7070
raise TypeError(
7171
f"Function attributes is not optional, but {name} is None."
7272
)
73-
func_attrs[name] = _attributes._Ref(concrete=attr, outer_name=name)
74-
self.func_attrs = func_attrs
73+
self.func_attrs[name] = attr
7574

7675
self.func_inputs = self.Inputs(**self.func_args) # type: ignore
7776
self.func_outputs = self.constructor(self.func_attrs, self.func_inputs)

src/spox/_graph.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,23 @@ def arguments_dict(**kwargs: Optional[Union[Type, numpy.ndarray]]) -> Dict[str,
3737
"""
3838
result = {}
3939
for name, info in kwargs.items():
40-
attr_name = AttrString(name)
40+
attr_name = AttrString(value=name, name="dummy")
4141
if isinstance(info, Type):
4242
result[name] = Argument(
43-
Argument.Attributes(name=attr_name, type=AttrType(info), default=None),
43+
Argument.Attributes(
44+
name=attr_name,
45+
type=AttrType(value=info, name="dummy"),
46+
default=None,
47+
),
4448
BaseInputs(),
4549
).outputs.arg
4650
elif isinstance(info, numpy.ndarray):
4751
ty = Tensor(info.dtype, info.shape)
4852
result[name] = Argument(
4953
Argument.Attributes(
50-
name=attr_name, type=AttrType(ty), default=AttrTensor(info)
54+
name=attr_name,
55+
type=AttrType(value=ty, name="dummy"),
56+
default=AttrTensor(value=info, name="dummy"),
5157
),
5258
BaseInputs(),
5359
).outputs.arg
@@ -101,7 +107,7 @@ def initializer(arr: numpy.ndarray) -> Var:
101107
Var which is always equal to the respective value provided by `arr`.
102108
"""
103109
return _Initializer(
104-
_Initializer.Attributes(value=AttrTensor(arr)),
110+
_Initializer.Attributes(value=AttrTensor(value=arr, name="dummy")),
105111
BaseInputs(),
106112
).outputs.arg
107113

src/spox/_node.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def to_onnx(
373373
subgraph = build_subgraph(self, key, attr.value)
374374
attr_proto = onnx.helper.make_attribute(key, subgraph)
375375
else:
376-
attr_proto = attr._to_onnx(key)
376+
attr_proto = attr._to_onnx()
377377
node_proto.attribute.append(attr_proto)
378378

379379
return [node_proto]

src/spox/_public.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def argument(typ: Type) -> Var:
3333
a model input to build a graph.
3434
"""
3535
return _internal_op.Argument(
36-
_internal_op.Argument.Attributes(type=AttrType(typ), default=None)
36+
_internal_op.Argument.Attributes(type=AttrType(typ, "dummy"), default=None)
3737
).outputs.arg
3838

3939

src/spox/_standard.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,9 @@ def to_singleton_onnx_model(
6565
# We inject the evaluated attribute values here and then substitute back
6666
self_attrs = self.attrs
6767
try:
68-
# Get exact attribute values to run inference (as
69-
# otherwise refs aren't handled properly).
68+
current_fields = self_attrs.get_fields().items()
7069
self.attrs = self.Attributes(
71-
**{
72-
k: type(v)(v.value) if v is not None else v
73-
for k, v in self.attrs.get_fields().items()
74-
}
70+
**{k: v.deref() if v is not None else None for k, v in current_fields}
7571
)
7672
node_proto: onnx.NodeProto
7773
# Subgraphs are not fully built for possibly significant performance gains.

0 commit comments

Comments
 (0)