|
| 1 | +# ruff: noqa: E741 -- Allow ambiguous variable name |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import ( |
| 4 | + Iterable, |
| 5 | + Optional, |
| 6 | +) |
| 7 | + |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +from spox._attributes import ( |
| 11 | + AttrInt64, |
| 12 | + AttrInt64s, |
| 13 | + AttrTensor, |
| 14 | +) |
| 15 | +from spox._fields import BaseAttributes, BaseInputs, BaseOutputs |
| 16 | +from spox._node import OpType |
| 17 | +from spox._standard import StandardNode |
| 18 | +from spox._var import Var |
| 19 | +from spox.opset.ai.onnx.ml.v4 import ( |
| 20 | + _ArrayFeatureExtractor, |
| 21 | + _Binarizer, |
| 22 | + _CastMap, |
| 23 | + _CategoryMapper, |
| 24 | + _DictVectorizer, |
| 25 | + _FeatureVectorizer, |
| 26 | + _Imputer, |
| 27 | + _LabelEncoder, |
| 28 | + _LinearClassifier, |
| 29 | + _LinearRegressor, |
| 30 | + _Normalizer, |
| 31 | + _OneHotEncoder, |
| 32 | + _Scaler, |
| 33 | + _SVMClassifier, |
| 34 | + _SVMRegressor, |
| 35 | + _ZipMap, |
| 36 | + array_feature_extractor, |
| 37 | + binarizer, |
| 38 | + cast_map, |
| 39 | + category_mapper, |
| 40 | + dict_vectorizer, |
| 41 | + feature_vectorizer, |
| 42 | + imputer, |
| 43 | + label_encoder, |
| 44 | + linear_classifier, |
| 45 | + linear_regressor, |
| 46 | + normalizer, |
| 47 | + one_hot_encoder, |
| 48 | + scaler, |
| 49 | + svmclassifier, |
| 50 | + svmregressor, |
| 51 | + zip_map, |
| 52 | +) |
| 53 | + |
| 54 | + |
| 55 | +class _TreeEnsemble(StandardNode): |
| 56 | + @dataclass |
| 57 | + class Attributes(BaseAttributes): |
| 58 | + aggregate_function: AttrInt64 |
| 59 | + leaf_targetids: AttrInt64s |
| 60 | + leaf_weights: AttrTensor |
| 61 | + membership_values: Optional[AttrTensor] |
| 62 | + n_targets: Optional[AttrInt64] |
| 63 | + nodes_falseleafs: AttrInt64s |
| 64 | + nodes_falsenodeids: AttrInt64s |
| 65 | + nodes_featureids: AttrInt64s |
| 66 | + nodes_hitrates: Optional[AttrTensor] |
| 67 | + nodes_missing_value_tracks_true: Optional[AttrInt64s] |
| 68 | + nodes_modes: AttrTensor |
| 69 | + nodes_splits: AttrTensor |
| 70 | + nodes_trueleafs: AttrInt64s |
| 71 | + nodes_truenodeids: AttrInt64s |
| 72 | + post_transform: AttrInt64 |
| 73 | + tree_roots: AttrInt64s |
| 74 | + |
| 75 | + @dataclass |
| 76 | + class Inputs(BaseInputs): |
| 77 | + X: Var |
| 78 | + |
| 79 | + @dataclass |
| 80 | + class Outputs(BaseOutputs): |
| 81 | + Y: Var |
| 82 | + |
| 83 | + op_type = OpType("TreeEnsemble", "ai.onnx.ml", 5) |
| 84 | + |
| 85 | + attrs: Attributes |
| 86 | + inputs: Inputs |
| 87 | + outputs: Outputs |
| 88 | + |
| 89 | + |
| 90 | +def tree_ensemble( |
| 91 | + X: Var, |
| 92 | + *, |
| 93 | + aggregate_function: int = 1, |
| 94 | + leaf_targetids: Iterable[int], |
| 95 | + leaf_weights: np.ndarray, |
| 96 | + membership_values: Optional[np.ndarray] = None, |
| 97 | + n_targets: Optional[int] = None, |
| 98 | + nodes_falseleafs: Iterable[int], |
| 99 | + nodes_falsenodeids: Iterable[int], |
| 100 | + nodes_featureids: Iterable[int], |
| 101 | + nodes_hitrates: Optional[np.ndarray] = None, |
| 102 | + nodes_missing_value_tracks_true: Optional[Iterable[int]] = None, |
| 103 | + nodes_modes: np.ndarray, |
| 104 | + nodes_splits: np.ndarray, |
| 105 | + nodes_trueleafs: Iterable[int], |
| 106 | + nodes_truenodeids: Iterable[int], |
| 107 | + post_transform: int = 0, |
| 108 | + tree_roots: Iterable[int], |
| 109 | +) -> Var: |
| 110 | + r""" |
| 111 | + Tree Ensemble operator. Returns the regressed values for each input in a |
| 112 | + batch. Inputs have dimensions ``[N, F]`` where ``N`` is the input batch |
| 113 | + size and ``F`` is the number of input features. Outputs have dimensions |
| 114 | + ``[N, num_targets]`` where ``N`` is the batch size and ``num_targets`` |
| 115 | + is the number of targets, which is a configurable attribute. |
| 116 | +
|
| 117 | + :: |
| 118 | +
|
| 119 | + The encoding of this attribute is split along interior nodes and the leaves of the trees. Notably, attributes with the prefix `nodes_*` are associated with interior nodes, and attributes with the prefix `leaf_*` are associated with leaves. |
| 120 | + The attributes `nodes_*` must all have the same length and encode a sequence of tuples, as defined by taking all the `nodes_*` fields at a given position. |
| 121 | +
|
| 122 | + All fields prefixed with `leaf_*` represent tree leaves, and similarly define tuples of leaves and must have identical length. |
| 123 | +
|
| 124 | + This operator can be used to implement both the previous `TreeEnsembleRegressor` and `TreeEnsembleClassifier` nodes. |
| 125 | + The `TreeEnsembleRegressor` node maps directly to this node and requires changing how the nodes are represented. |
| 126 | + The `TreeEnsembleClassifier` node can be implemented by adding a `ArgMax` node after this node to determine the top class. |
| 127 | + To encode class labels, a `LabelEncoder` or `GatherND` operator may be used. |
| 128 | +
|
| 129 | + Parameters |
| 130 | + ========== |
| 131 | + X |
| 132 | + Type T. |
| 133 | + Input of shape [Batch Size, Number of Features] |
| 134 | + aggregate_function |
| 135 | + Attribute. |
| 136 | + Defines how to aggregate leaf values within a target. One of 'AVERAGE' |
| 137 | + (0) 'SUM' (1) 'MIN' (2) 'MAX (3) defaults to 'SUM' (1) |
| 138 | + leaf_targetids |
| 139 | + Attribute. |
| 140 | + The index of the target that this leaf contributes to (this must be in |
| 141 | + range ``[0, n_targets)``). |
| 142 | + leaf_weights |
| 143 | + Attribute. |
| 144 | + The weight for each leaf. |
| 145 | + membership_values |
| 146 | + Attribute. |
| 147 | + Members to test membership of for each set membership node. List all of |
| 148 | + the members to test again in the order that the 'BRANCH_MEMBER' mode |
| 149 | + appears in ``node_modes``, delimited by ``NaN``\ s. Will have the same |
| 150 | + number of sets of values as nodes with mode 'BRANCH_MEMBER'. This may be |
| 151 | + omitted if the node doesn't contain any 'BRANCH_MEMBER' nodes. |
| 152 | + n_targets |
| 153 | + Attribute. |
| 154 | + The total number of targets. |
| 155 | + nodes_falseleafs |
| 156 | + Attribute. |
| 157 | + 1 if false branch is leaf for each node and 0 if an interior node. To |
| 158 | + represent a tree that is a leaf (only has one node), one can do so by |
| 159 | + having a single ``nodes_*`` entry with true and false branches |
| 160 | + referencing the same ``leaf_*`` entry |
| 161 | + nodes_falsenodeids |
| 162 | + Attribute. |
| 163 | + If ``nodes_falseleafs`` is false at an entry, this represents the |
| 164 | + position of the false branch node. This position can be used to index |
| 165 | + into a ``nodes_*`` entry. If ``nodes_falseleafs`` is false, it is an |
| 166 | + index into the leaf\_\* attributes. |
| 167 | + nodes_featureids |
| 168 | + Attribute. |
| 169 | + Feature id for each node. |
| 170 | + nodes_hitrates |
| 171 | + Attribute. |
| 172 | + Popularity of each node, used for performance and may be omitted. |
| 173 | + nodes_missing_value_tracks_true |
| 174 | + Attribute. |
| 175 | + For each node, define whether to follow the true branch (if attribute |
| 176 | + value is 1) or false branch (if attribute value is 0) in the presence of |
| 177 | + a NaN input feature. This attribute may be left undefined and the |
| 178 | + default value is false (0) for all nodes. |
| 179 | + nodes_modes |
| 180 | + Attribute. |
| 181 | + The comparison operation performed by the node. This is encoded as an |
| 182 | + enumeration of 0 ('BRANCH_LEQ'), 1 ('BRANCH_LT'), 2 ('BRANCH_GTE'), 3 |
| 183 | + ('BRANCH_GT'), 4 ('BRANCH_EQ'), 5 ('BRANCH_NEQ'), and 6 |
| 184 | + ('BRANCH_MEMBER'). Note this is a tensor of type uint8. |
| 185 | + nodes_splits |
| 186 | + Attribute. |
| 187 | + Thresholds to do the splitting on for each node with mode that is not |
| 188 | + 'BRANCH_MEMBER'. |
| 189 | + nodes_trueleafs |
| 190 | + Attribute. |
| 191 | + 1 if true branch is leaf for each node and 0 an interior node. To |
| 192 | + represent a tree that is a leaf (only has one node), one can do so by |
| 193 | + having a single ``nodes_*`` entry with true and false branches |
| 194 | + referencing the same ``leaf_*`` entry |
| 195 | + nodes_truenodeids |
| 196 | + Attribute. |
| 197 | + If ``nodes_trueleafs`` is false at an entry, this represents the |
| 198 | + position of the true branch node. This position can be used to index |
| 199 | + into a ``nodes_*`` entry. If ``nodes_trueleafs`` is false, it is an |
| 200 | + index into the leaf\_\* attributes. |
| 201 | + post_transform |
| 202 | + Attribute. |
| 203 | + Indicates the transform to apply to the score. One of 'NONE' (0), |
| 204 | + 'SOFTMAX' (1), 'LOGISTIC' (2), 'SOFTMAX_ZERO' (3) or 'PROBIT' (4), |
| 205 | + defaults to 'NONE' (0) |
| 206 | + tree_roots |
| 207 | + Attribute. |
| 208 | + Index into ``nodes_*`` for the root of each tree. The tree structure is |
| 209 | + derived from the branching of each node. |
| 210 | +
|
| 211 | + Returns |
| 212 | + ======= |
| 213 | + Y : Var |
| 214 | + Type T. |
| 215 | + Output of shape [Batch Size, Number of targets] |
| 216 | +
|
| 217 | + Notes |
| 218 | + ===== |
| 219 | + Signature: ``ai.onnx.ml@5::TreeEnsemble``. |
| 220 | +
|
| 221 | + Type constraints: |
| 222 | + - T: `tensor(double)`, `tensor(float)`, `tensor(float16)` |
| 223 | + """ |
| 224 | + return _TreeEnsemble( |
| 225 | + _TreeEnsemble.Attributes( |
| 226 | + aggregate_function=AttrInt64(aggregate_function, name="aggregate_function"), |
| 227 | + leaf_targetids=AttrInt64s(leaf_targetids, name="leaf_targetids"), |
| 228 | + leaf_weights=AttrTensor(leaf_weights, name="leaf_weights"), |
| 229 | + membership_values=AttrTensor.maybe( |
| 230 | + membership_values, name="membership_values" |
| 231 | + ), |
| 232 | + n_targets=AttrInt64.maybe(n_targets, name="n_targets"), |
| 233 | + nodes_falseleafs=AttrInt64s(nodes_falseleafs, name="nodes_falseleafs"), |
| 234 | + nodes_falsenodeids=AttrInt64s( |
| 235 | + nodes_falsenodeids, name="nodes_falsenodeids" |
| 236 | + ), |
| 237 | + nodes_featureids=AttrInt64s(nodes_featureids, name="nodes_featureids"), |
| 238 | + nodes_hitrates=AttrTensor.maybe(nodes_hitrates, name="nodes_hitrates"), |
| 239 | + nodes_missing_value_tracks_true=AttrInt64s.maybe( |
| 240 | + nodes_missing_value_tracks_true, name="nodes_missing_value_tracks_true" |
| 241 | + ), |
| 242 | + nodes_modes=AttrTensor(nodes_modes, name="nodes_modes"), |
| 243 | + nodes_splits=AttrTensor(nodes_splits, name="nodes_splits"), |
| 244 | + nodes_trueleafs=AttrInt64s(nodes_trueleafs, name="nodes_trueleafs"), |
| 245 | + nodes_truenodeids=AttrInt64s(nodes_truenodeids, name="nodes_truenodeids"), |
| 246 | + post_transform=AttrInt64(post_transform, name="post_transform"), |
| 247 | + tree_roots=AttrInt64s(tree_roots, name="tree_roots"), |
| 248 | + ), |
| 249 | + _TreeEnsemble.Inputs( |
| 250 | + X=X, |
| 251 | + ), |
| 252 | + ).outputs.Y |
| 253 | + |
| 254 | + |
| 255 | +_OPERATORS = { |
| 256 | + "ArrayFeatureExtractor": _ArrayFeatureExtractor, |
| 257 | + "Binarizer": _Binarizer, |
| 258 | + "CastMap": _CastMap, |
| 259 | + "CategoryMapper": _CategoryMapper, |
| 260 | + "DictVectorizer": _DictVectorizer, |
| 261 | + "FeatureVectorizer": _FeatureVectorizer, |
| 262 | + "Imputer": _Imputer, |
| 263 | + "LabelEncoder": _LabelEncoder, |
| 264 | + "LinearClassifier": _LinearClassifier, |
| 265 | + "LinearRegressor": _LinearRegressor, |
| 266 | + "Normalizer": _Normalizer, |
| 267 | + "OneHotEncoder": _OneHotEncoder, |
| 268 | + "SVMClassifier": _SVMClassifier, |
| 269 | + "SVMRegressor": _SVMRegressor, |
| 270 | + "Scaler": _Scaler, |
| 271 | + "TreeEnsemble": _TreeEnsemble, |
| 272 | + "ZipMap": _ZipMap, |
| 273 | +} |
| 274 | + |
| 275 | +_CONSTRUCTORS = { |
| 276 | + "ArrayFeatureExtractor": array_feature_extractor, |
| 277 | + "Binarizer": binarizer, |
| 278 | + "CastMap": cast_map, |
| 279 | + "CategoryMapper": category_mapper, |
| 280 | + "DictVectorizer": dict_vectorizer, |
| 281 | + "FeatureVectorizer": feature_vectorizer, |
| 282 | + "Imputer": imputer, |
| 283 | + "LabelEncoder": label_encoder, |
| 284 | + "LinearClassifier": linear_classifier, |
| 285 | + "LinearRegressor": linear_regressor, |
| 286 | + "Normalizer": normalizer, |
| 287 | + "OneHotEncoder": one_hot_encoder, |
| 288 | + "SVMClassifier": svmclassifier, |
| 289 | + "SVMRegressor": svmregressor, |
| 290 | + "Scaler": scaler, |
| 291 | + "TreeEnsemble": tree_ensemble, |
| 292 | + "ZipMap": zip_map, |
| 293 | +} |
| 294 | + |
| 295 | +__all__ = [fun.__name__ for fun in _CONSTRUCTORS.values()] |
0 commit comments