23
23
24
24
class Attr (ABC , Generic [T ]):
25
25
_value : Union [T , "_Ref[T]" ]
26
+ _name : str
27
+ _cached_onnx : Optional [AttributeProto ]
26
28
27
- def __init__ (self , value : Union [T , "_Ref[T]" ]):
29
+ def __init__ (self , value : Union [T , "_Ref[T]" ], name : str ):
28
30
self ._value = value
31
+ self ._name = name
32
+ self ._cached_onnx = None
33
+
29
34
self ._validate ()
30
35
36
+ def deref (self ) -> "Attr" :
37
+ if isinstance (self ._value , _Ref ):
38
+ return type (self )(self .value , self ._name )
39
+ else :
40
+ return self
41
+
31
42
@classmethod
32
- def maybe (cls : Type [AttrT ], value : Optional [T ]) -> Optional [AttrT ]:
33
- return cls (value ) if value is not None else None
43
+ def maybe (cls : Type [AttrT ], value : Optional [T ], name : str ) -> Optional [AttrT ]:
44
+ return cls (value , name ) if value is not None else None
34
45
35
46
@property
36
47
def value (self ) -> T :
@@ -41,7 +52,7 @@ def value(self) -> T:
41
52
42
53
def _validate (self ):
43
54
try :
44
- type_in_onnx = self ._to_onnx_deref ( "dummy" ).type
55
+ type_in_onnx = self ._to_onnx ( ).type
45
56
except Exception as e :
46
57
# Likely an error from within onnx/protobuf, such as:
47
58
# 1) AttributeError: 'int' object has no attribute 'encode'
@@ -52,18 +63,19 @@ def _validate(self):
52
63
if type_in_onnx != self ._attribute_proto_type :
53
64
raise self ._get_pretty_type_exception ()
54
65
55
- def _to_onnx (self , key : str ) -> AttributeProto :
66
+ def _to_onnx (self ) -> AttributeProto :
56
67
if isinstance (self ._value , _Ref ):
57
- return self ._value ._to_onnx (key )
58
- return self ._to_onnx_deref (key )
68
+ return self ._value ._to_onnx ()
69
+ if self ._cached_onnx is None :
70
+ self ._cached_onnx = self ._to_onnx_deref ()
71
+ return self ._cached_onnx
59
72
60
73
@property
61
74
@abc .abstractmethod
62
75
def _attribute_proto_type (self ) -> int :
63
76
raise NotImplementedError ()
64
77
65
- @abc .abstractmethod
66
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
78
+ def _to_onnx_deref (self ) -> AttributeProto :
67
79
"""Conversion method for the dereferenced case."""
68
80
raise NotImplementedError ()
69
81
@@ -87,57 +99,58 @@ class _Ref(Generic[T]):
87
99
88
100
_concrete : Attr [T ]
89
101
90
- def __init__ (self , concrete : Attr [T ], outer_name : str ):
102
+ def __init__ (self , concrete : Attr [T ], outer_name : str , name : str ):
91
103
self ._concrete = concrete
92
104
self ._outer_name = outer_name
105
+ self ._name = name
93
106
94
107
def copy (self ) -> "_Ref[T]" :
95
108
return self
96
109
97
- def _to_onnx (self , key : str ) -> AttributeProto :
98
- parent_type = self ._concrete ._to_onnx (key ).type
110
+ def _to_onnx (self ) -> AttributeProto :
111
+ parent_type = self ._concrete ._to_onnx ().type
99
112
return AttributeProto (
100
- name = key , ref_attr_name = self ._outer_name , type = parent_type
113
+ name = self . _name , ref_attr_name = self ._outer_name , type = parent_type
101
114
)
102
115
103
116
104
117
class AttrFloat32 (Attr [float ]):
105
118
_attribute_proto_type = AttributeProto .FLOAT
106
119
107
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
120
+ def _to_onnx_deref (self ) -> AttributeProto :
108
121
if isinstance (self .value , int ):
109
- return make_attribute (key , float (self .value ))
110
- return make_attribute (key , self .value )
122
+ return make_attribute (self . _name , float (self .value ))
123
+ return make_attribute (self . _name , self .value )
111
124
112
125
113
126
class AttrInt64 (Attr [int ]):
114
127
_attribute_proto_type = AttributeProto .INT
115
128
116
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
117
- return make_attribute (key , self .value )
129
+ def _to_onnx_deref (self ) -> AttributeProto :
130
+ return make_attribute (self . _name , self .value )
118
131
119
132
120
133
class AttrString (Attr [str ]):
121
134
_attribute_proto_type = AttributeProto .STRING
122
135
123
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
124
- return make_attribute (key , self .value )
136
+ def _to_onnx_deref (self ) -> AttributeProto :
137
+ return make_attribute (self . _name , self .value )
125
138
126
139
127
140
class AttrTensor (Attr [np .ndarray ]):
128
141
_attribute_proto_type = AttributeProto .TENSOR
129
142
130
- def __init__ (self , value : Union [np .ndarray , _Ref [np .ndarray ]]):
131
- super ().__init__ (value .copy ())
143
+ def __init__ (self , value : Union [np .ndarray , _Ref [np .ndarray ]], name : str ):
144
+ super ().__init__ (value .copy (), name )
132
145
133
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
134
- return make_attribute (key , from_array (self .value ))
146
+ def _to_onnx_deref (self ) -> AttributeProto :
147
+ return make_attribute (self . _name , from_array (self .value ))
135
148
136
149
137
150
class AttrType (Attr [_type_system .Type ]):
138
151
_attribute_proto_type = AttributeProto .TYPE_PROTO
139
152
140
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
153
+ def _to_onnx_deref (self ) -> AttributeProto :
141
154
value = self .value # for type-checkers with limited property support
142
155
if isinstance (value , _type_system .Tensor ):
143
156
type_proto = make_tensor_type_proto (
@@ -150,7 +163,7 @@ def _to_onnx_deref(self, key: str) -> AttributeProto:
150
163
type_proto = make_optional_type_proto (value .elem_type ._to_onnx ())
151
164
else :
152
165
raise NotImplementedError ()
153
- return make_attribute (key , type_proto )
166
+ return make_attribute (self . _name , type_proto )
154
167
155
168
156
169
class AttrDtype (Attr [npt .DTypeLike ]):
@@ -161,8 +174,8 @@ class AttrDtype(Attr[npt.DTypeLike]):
161
174
def _validate (self ):
162
175
dtype_to_tensor_type (self .value )
163
176
164
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
165
- return make_attribute (key , dtype_to_tensor_type (self .value ))
177
+ def _to_onnx_deref (self ) -> AttributeProto :
178
+ return make_attribute (self . _name , dtype_to_tensor_type (self .value ))
166
179
167
180
168
181
class AttrGraph (Attr [Any ]):
@@ -176,24 +189,30 @@ def _validate(self):
176
189
f"Expected value of type `spox.graph.Graph found `{ type (self .value )} `"
177
190
)
178
191
179
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
192
+ def _to_onnx_deref (self ) -> AttributeProto :
180
193
raise TypeError (
181
194
"Graph attributes must be built using the `build_subgraph` callback in `Node.to_onnx`."
182
195
)
183
196
184
197
185
198
class _AttrIterable (Attr [Tuple [S , ...]], ABC ):
186
- def __init__ (self , value : Union [Iterable [S ], _Ref [Tuple [S , ...]]]):
187
- super ().__init__ (value if isinstance (value , _Ref ) else tuple (value ))
199
+ def __init__ (self , value : Union [Iterable [S ], _Ref [Tuple [S , ...]]], name : str ):
200
+ super ().__init__ (
201
+ value = value if isinstance (value , _Ref ) else tuple (value ), name = name
202
+ )
188
203
189
204
@classmethod
190
205
def maybe (
191
- cls : Type [AttrIterableT ], value : Optional [Iterable [S ]]
206
+ cls : Type [AttrIterableT ],
207
+ value : Optional [Iterable [S ]],
208
+ name : str ,
192
209
) -> Optional [AttrIterableT ]:
193
- return cls (tuple (value )) if value is not None else None
210
+ return cls (tuple (value ), name ) if value is not None else None
194
211
195
- def _to_onnx_deref (self , key : str ) -> AttributeProto :
196
- return make_attribute (key , self .value , attr_type = self ._attribute_proto_type )
212
+ def _to_onnx_deref (self ) -> AttributeProto :
213
+ return make_attribute (
214
+ self ._name , self .value , attr_type = self ._attribute_proto_type
215
+ )
197
216
198
217
199
218
class AttrFloat32s (_AttrIterable [float ]):
0 commit comments