Skip to content

Commit 0731c0a

Browse files
authored
Add opsets ai.onnx@21 and ai.onnx.ml@5 (#149)
1 parent 47b13ed commit 0731c0a

File tree

4 files changed

+3341
-0
lines changed

4 files changed

+3341
-0
lines changed

CHANGELOG.rst

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
Change log
88
==========
99

10+
0.11.0 (2024-04-23)
11+
-------------------
12+
13+
**New feature**
14+
15+
- The opsets ``ai.onnx`` version 21 and ``ai.onnx.ml`` version 5 (released with ONNX 1.16) are now shipped with Spox.
16+
17+
1018
0.10.3 (2024-03-14)
1119
-------------------
1220

src/spox/opset/ai/onnx/ml/v5.py

+295
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
# ruff: noqa: E741 -- Allow ambiguous variable name
2+
from dataclasses import dataclass
3+
from typing import (
4+
Iterable,
5+
Optional,
6+
)
7+
8+
import numpy as np
9+
10+
from spox._attributes import (
11+
AttrInt64,
12+
AttrInt64s,
13+
AttrTensor,
14+
)
15+
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs
16+
from spox._node import OpType
17+
from spox._standard import StandardNode
18+
from spox._var import Var
19+
from spox.opset.ai.onnx.ml.v4 import (
20+
_ArrayFeatureExtractor,
21+
_Binarizer,
22+
_CastMap,
23+
_CategoryMapper,
24+
_DictVectorizer,
25+
_FeatureVectorizer,
26+
_Imputer,
27+
_LabelEncoder,
28+
_LinearClassifier,
29+
_LinearRegressor,
30+
_Normalizer,
31+
_OneHotEncoder,
32+
_Scaler,
33+
_SVMClassifier,
34+
_SVMRegressor,
35+
_ZipMap,
36+
array_feature_extractor,
37+
binarizer,
38+
cast_map,
39+
category_mapper,
40+
dict_vectorizer,
41+
feature_vectorizer,
42+
imputer,
43+
label_encoder,
44+
linear_classifier,
45+
linear_regressor,
46+
normalizer,
47+
one_hot_encoder,
48+
scaler,
49+
svmclassifier,
50+
svmregressor,
51+
zip_map,
52+
)
53+
54+
55+
class _TreeEnsemble(StandardNode):
56+
@dataclass
57+
class Attributes(BaseAttributes):
58+
aggregate_function: AttrInt64
59+
leaf_targetids: AttrInt64s
60+
leaf_weights: AttrTensor
61+
membership_values: Optional[AttrTensor]
62+
n_targets: Optional[AttrInt64]
63+
nodes_falseleafs: AttrInt64s
64+
nodes_falsenodeids: AttrInt64s
65+
nodes_featureids: AttrInt64s
66+
nodes_hitrates: Optional[AttrTensor]
67+
nodes_missing_value_tracks_true: Optional[AttrInt64s]
68+
nodes_modes: AttrTensor
69+
nodes_splits: AttrTensor
70+
nodes_trueleafs: AttrInt64s
71+
nodes_truenodeids: AttrInt64s
72+
post_transform: AttrInt64
73+
tree_roots: AttrInt64s
74+
75+
@dataclass
76+
class Inputs(BaseInputs):
77+
X: Var
78+
79+
@dataclass
80+
class Outputs(BaseOutputs):
81+
Y: Var
82+
83+
op_type = OpType("TreeEnsemble", "ai.onnx.ml", 5)
84+
85+
attrs: Attributes
86+
inputs: Inputs
87+
outputs: Outputs
88+
89+
90+
def tree_ensemble(
91+
X: Var,
92+
*,
93+
aggregate_function: int = 1,
94+
leaf_targetids: Iterable[int],
95+
leaf_weights: np.ndarray,
96+
membership_values: Optional[np.ndarray] = None,
97+
n_targets: Optional[int] = None,
98+
nodes_falseleafs: Iterable[int],
99+
nodes_falsenodeids: Iterable[int],
100+
nodes_featureids: Iterable[int],
101+
nodes_hitrates: Optional[np.ndarray] = None,
102+
nodes_missing_value_tracks_true: Optional[Iterable[int]] = None,
103+
nodes_modes: np.ndarray,
104+
nodes_splits: np.ndarray,
105+
nodes_trueleafs: Iterable[int],
106+
nodes_truenodeids: Iterable[int],
107+
post_transform: int = 0,
108+
tree_roots: Iterable[int],
109+
) -> Var:
110+
r"""
111+
Tree Ensemble operator. Returns the regressed values for each input in a
112+
batch. Inputs have dimensions ``[N, F]`` where ``N`` is the input batch
113+
size and ``F`` is the number of input features. Outputs have dimensions
114+
``[N, num_targets]`` where ``N`` is the batch size and ``num_targets``
115+
is the number of targets, which is a configurable attribute.
116+
117+
::
118+
119+
The encoding of this attribute is split along interior nodes and the leaves of the trees. Notably, attributes with the prefix `nodes_*` are associated with interior nodes, and attributes with the prefix `leaf_*` are associated with leaves.
120+
The attributes `nodes_*` must all have the same length and encode a sequence of tuples, as defined by taking all the `nodes_*` fields at a given position.
121+
122+
All fields prefixed with `leaf_*` represent tree leaves, and similarly define tuples of leaves and must have identical length.
123+
124+
This operator can be used to implement both the previous `TreeEnsembleRegressor` and `TreeEnsembleClassifier` nodes.
125+
The `TreeEnsembleRegressor` node maps directly to this node and requires changing how the nodes are represented.
126+
The `TreeEnsembleClassifier` node can be implemented by adding a `ArgMax` node after this node to determine the top class.
127+
To encode class labels, a `LabelEncoder` or `GatherND` operator may be used.
128+
129+
Parameters
130+
==========
131+
X
132+
Type T.
133+
Input of shape [Batch Size, Number of Features]
134+
aggregate_function
135+
Attribute.
136+
Defines how to aggregate leaf values within a target. One of 'AVERAGE'
137+
(0) 'SUM' (1) 'MIN' (2) 'MAX (3) defaults to 'SUM' (1)
138+
leaf_targetids
139+
Attribute.
140+
The index of the target that this leaf contributes to (this must be in
141+
range ``[0, n_targets)``).
142+
leaf_weights
143+
Attribute.
144+
The weight for each leaf.
145+
membership_values
146+
Attribute.
147+
Members to test membership of for each set membership node. List all of
148+
the members to test again in the order that the 'BRANCH_MEMBER' mode
149+
appears in ``node_modes``, delimited by ``NaN``\ s. Will have the same
150+
number of sets of values as nodes with mode 'BRANCH_MEMBER'. This may be
151+
omitted if the node doesn't contain any 'BRANCH_MEMBER' nodes.
152+
n_targets
153+
Attribute.
154+
The total number of targets.
155+
nodes_falseleafs
156+
Attribute.
157+
1 if false branch is leaf for each node and 0 if an interior node. To
158+
represent a tree that is a leaf (only has one node), one can do so by
159+
having a single ``nodes_*`` entry with true and false branches
160+
referencing the same ``leaf_*`` entry
161+
nodes_falsenodeids
162+
Attribute.
163+
If ``nodes_falseleafs`` is false at an entry, this represents the
164+
position of the false branch node. This position can be used to index
165+
into a ``nodes_*`` entry. If ``nodes_falseleafs`` is false, it is an
166+
index into the leaf\_\* attributes.
167+
nodes_featureids
168+
Attribute.
169+
Feature id for each node.
170+
nodes_hitrates
171+
Attribute.
172+
Popularity of each node, used for performance and may be omitted.
173+
nodes_missing_value_tracks_true
174+
Attribute.
175+
For each node, define whether to follow the true branch (if attribute
176+
value is 1) or false branch (if attribute value is 0) in the presence of
177+
a NaN input feature. This attribute may be left undefined and the
178+
default value is false (0) for all nodes.
179+
nodes_modes
180+
Attribute.
181+
The comparison operation performed by the node. This is encoded as an
182+
enumeration of 0 ('BRANCH_LEQ'), 1 ('BRANCH_LT'), 2 ('BRANCH_GTE'), 3
183+
('BRANCH_GT'), 4 ('BRANCH_EQ'), 5 ('BRANCH_NEQ'), and 6
184+
('BRANCH_MEMBER'). Note this is a tensor of type uint8.
185+
nodes_splits
186+
Attribute.
187+
Thresholds to do the splitting on for each node with mode that is not
188+
'BRANCH_MEMBER'.
189+
nodes_trueleafs
190+
Attribute.
191+
1 if true branch is leaf for each node and 0 an interior node. To
192+
represent a tree that is a leaf (only has one node), one can do so by
193+
having a single ``nodes_*`` entry with true and false branches
194+
referencing the same ``leaf_*`` entry
195+
nodes_truenodeids
196+
Attribute.
197+
If ``nodes_trueleafs`` is false at an entry, this represents the
198+
position of the true branch node. This position can be used to index
199+
into a ``nodes_*`` entry. If ``nodes_trueleafs`` is false, it is an
200+
index into the leaf\_\* attributes.
201+
post_transform
202+
Attribute.
203+
Indicates the transform to apply to the score. One of 'NONE' (0),
204+
'SOFTMAX' (1), 'LOGISTIC' (2), 'SOFTMAX_ZERO' (3) or 'PROBIT' (4),
205+
defaults to 'NONE' (0)
206+
tree_roots
207+
Attribute.
208+
Index into ``nodes_*`` for the root of each tree. The tree structure is
209+
derived from the branching of each node.
210+
211+
Returns
212+
=======
213+
Y : Var
214+
Type T.
215+
Output of shape [Batch Size, Number of targets]
216+
217+
Notes
218+
=====
219+
Signature: ``ai.onnx.ml@5::TreeEnsemble``.
220+
221+
Type constraints:
222+
- T: `tensor(double)`, `tensor(float)`, `tensor(float16)`
223+
"""
224+
return _TreeEnsemble(
225+
_TreeEnsemble.Attributes(
226+
aggregate_function=AttrInt64(aggregate_function, name="aggregate_function"),
227+
leaf_targetids=AttrInt64s(leaf_targetids, name="leaf_targetids"),
228+
leaf_weights=AttrTensor(leaf_weights, name="leaf_weights"),
229+
membership_values=AttrTensor.maybe(
230+
membership_values, name="membership_values"
231+
),
232+
n_targets=AttrInt64.maybe(n_targets, name="n_targets"),
233+
nodes_falseleafs=AttrInt64s(nodes_falseleafs, name="nodes_falseleafs"),
234+
nodes_falsenodeids=AttrInt64s(
235+
nodes_falsenodeids, name="nodes_falsenodeids"
236+
),
237+
nodes_featureids=AttrInt64s(nodes_featureids, name="nodes_featureids"),
238+
nodes_hitrates=AttrTensor.maybe(nodes_hitrates, name="nodes_hitrates"),
239+
nodes_missing_value_tracks_true=AttrInt64s.maybe(
240+
nodes_missing_value_tracks_true, name="nodes_missing_value_tracks_true"
241+
),
242+
nodes_modes=AttrTensor(nodes_modes, name="nodes_modes"),
243+
nodes_splits=AttrTensor(nodes_splits, name="nodes_splits"),
244+
nodes_trueleafs=AttrInt64s(nodes_trueleafs, name="nodes_trueleafs"),
245+
nodes_truenodeids=AttrInt64s(nodes_truenodeids, name="nodes_truenodeids"),
246+
post_transform=AttrInt64(post_transform, name="post_transform"),
247+
tree_roots=AttrInt64s(tree_roots, name="tree_roots"),
248+
),
249+
_TreeEnsemble.Inputs(
250+
X=X,
251+
),
252+
).outputs.Y
253+
254+
255+
_OPERATORS = {
256+
"ArrayFeatureExtractor": _ArrayFeatureExtractor,
257+
"Binarizer": _Binarizer,
258+
"CastMap": _CastMap,
259+
"CategoryMapper": _CategoryMapper,
260+
"DictVectorizer": _DictVectorizer,
261+
"FeatureVectorizer": _FeatureVectorizer,
262+
"Imputer": _Imputer,
263+
"LabelEncoder": _LabelEncoder,
264+
"LinearClassifier": _LinearClassifier,
265+
"LinearRegressor": _LinearRegressor,
266+
"Normalizer": _Normalizer,
267+
"OneHotEncoder": _OneHotEncoder,
268+
"SVMClassifier": _SVMClassifier,
269+
"SVMRegressor": _SVMRegressor,
270+
"Scaler": _Scaler,
271+
"TreeEnsemble": _TreeEnsemble,
272+
"ZipMap": _ZipMap,
273+
}
274+
275+
_CONSTRUCTORS = {
276+
"ArrayFeatureExtractor": array_feature_extractor,
277+
"Binarizer": binarizer,
278+
"CastMap": cast_map,
279+
"CategoryMapper": category_mapper,
280+
"DictVectorizer": dict_vectorizer,
281+
"FeatureVectorizer": feature_vectorizer,
282+
"Imputer": imputer,
283+
"LabelEncoder": label_encoder,
284+
"LinearClassifier": linear_classifier,
285+
"LinearRegressor": linear_regressor,
286+
"Normalizer": normalizer,
287+
"OneHotEncoder": one_hot_encoder,
288+
"SVMClassifier": svmclassifier,
289+
"SVMRegressor": svmregressor,
290+
"Scaler": scaler,
291+
"TreeEnsemble": tree_ensemble,
292+
"ZipMap": zip_map,
293+
}
294+
295+
__all__ = [fun.__name__ for fun in _CONSTRUCTORS.values()]

0 commit comments

Comments
 (0)