Skip to content

Commit d6c9b85

Browse files
authored
Imporve ONNX 1.14 compatibility (#134)
1 parent 8cb7271 commit d6c9b85

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

CHANGELOG.rst

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
Change log
88
==========
99

10+
0.10.1 (2023-02-07)
11+
-------------------
12+
13+
**Other changes**
14+
15+
- Spox's compatibility with older versions of onnx has been improved.
16+
17+
1018
0.10.0 (2023-02-02)
1119
-------------------
1220

src/spox/_attributes.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44

55
import numpy as np
66
import numpy.typing as npt
7+
import onnx
78
from onnx import AttributeProto
89
from onnx.helper import (
910
make_attribute,
1011
make_optional_type_proto,
1112
make_sequence_type_proto,
1213
make_tensor_type_proto,
1314
)
15+
from packaging import version
1416

1517
from spox import _type_system
1618
from spox._utils import dtype_to_tensor_type, from_array
@@ -210,9 +212,13 @@ def maybe(
210212
return cls(tuple(value), name) if value is not None else None
211213

212214
def _to_onnx_deref(self) -> AttributeProto:
213-
return make_attribute(
214-
self._name, self.value, attr_type=self._attribute_proto_type
215-
)
215+
# 1.15 introduced attr_type which provides much better performance
216+
if version.parse(onnx.__version__) >= version.parse("1.15"):
217+
return make_attribute(
218+
self._name, self.value, attr_type=self._attribute_proto_type
219+
)
220+
else:
221+
return make_attribute(self._name, self.value)
216222

217223

218224
class AttrFloat32s(_AttrIterable[float]):

tests/type_inference/test_scaler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ def test_scaler_inference():
1616
def test_scaler_inference_fails_mismatched_lengths():
1717
(x,) = arguments(x=Tensor(np.float64, ("N", 3)))
1818
with pytest.raises(InferenceError):
19-
op_ml.scaler(x, offset=[0.0, 0.1], scale=[1])
19+
op_ml.scaler(x, offset=[0.0, 0.1], scale=[1.0])

0 commit comments

Comments
 (0)