19
19
#include < c10/util/Logging.h>
20
20
#include < c10/util/Optional.h>
21
21
#include < c10/util/accumulate.h>
22
+ #include < c10/util/intrusive_ptr.h>
22
23
#include < c10/util/irange.h>
23
24
#include < c10/util/python_stub.h>
24
25
#include < c10/util/safe_numerics.h>
@@ -217,6 +218,18 @@ is_channels_last_3d
217
218
is_non_overlapping_and_dense
218
219
#endif
219
220
221
+ /* *
222
+ * This structure is intended to hold additional metadata of the specific device
223
+ *backend
224
+ **/
225
+ struct C10_API BackendMeta : intrusive_ptr_target {
226
+ virtual ~BackendMeta (){};
227
+ virtual intrusive_ptr<BackendMeta> clone (
228
+ const intrusive_ptr<BackendMeta>& ptr) const {
229
+ return ptr;
230
+ }
231
+ };
232
+
220
233
struct C10_API ExtraMeta {
221
234
SymDimVector sizes_ = {0 };
222
235
SymDimVector strides_ = {1 };
@@ -229,6 +242,7 @@ struct C10_API ExtraMeta {
229
242
SymBool is_channels_last_3d_{false };
230
243
SymBool is_non_overlapping_and_dense_{true };
231
244
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr ;
245
+ intrusive_ptr<c10::BackendMeta> backend_meta_;
232
246
233
247
ExtraMeta () = default ;
234
248
@@ -243,7 +257,8 @@ struct C10_API ExtraMeta {
243
257
SymBool is_channels_last,
244
258
SymBool is_channels_last_3d,
245
259
SymBool is_non_overlapping_and_dense,
246
- std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta)
260
+ std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta,
261
+ intrusive_ptr<c10::BackendMeta> backend_meta)
247
262
: sizes_(std::move(sizes)),
248
263
strides_ (std::move(strides)),
249
264
numel_(std::move(numel)),
@@ -255,7 +270,8 @@ struct C10_API ExtraMeta {
255
270
is_channels_last_(std::move(is_channels_last)),
256
271
is_channels_last_3d_(std::move(is_channels_last_3d)),
257
272
is_non_overlapping_and_dense_(std::move(is_non_overlapping_and_dense)),
258
- named_tensor_meta_(std::move(named_tensor_meta)) {}
273
+ named_tensor_meta_(std::move(named_tensor_meta)),
274
+ backend_meta_(backend_meta) {}
259
275
260
276
std::unique_ptr<ExtraMeta> clone () const {
261
277
return std::make_unique<ExtraMeta>(
@@ -269,7 +285,8 @@ struct C10_API ExtraMeta {
269
285
is_channels_last_,
270
286
is_channels_last_3d_,
271
287
is_non_overlapping_and_dense_,
272
- named_tensor_meta_ ? named_tensor_meta_->clone () : nullptr );
288
+ named_tensor_meta_ ? named_tensor_meta_->clone () : nullptr ,
289
+ backend_meta_ ? backend_meta_->clone (backend_meta_) : nullptr );
273
290
}
274
291
};
275
292
@@ -1576,6 +1593,27 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
1576
1593
return data_type_.itemsize ();
1577
1594
}
1578
1595
1596
+ void set_backend_meta (intrusive_ptr<c10::BackendMeta> backend_meta) {
1597
+ if (!extra_meta_) {
1598
+ extra_meta_ = std::make_unique<ExtraMeta>();
1599
+ }
1600
+ extra_meta_->backend_meta_ = std::move (backend_meta);
1601
+ }
1602
+
1603
+ c10::BackendMeta* get_backend_meta () {
1604
+ if (!extra_meta_) {
1605
+ return nullptr ;
1606
+ }
1607
+ return extra_meta_->backend_meta_ .get ();
1608
+ }
1609
+
1610
+ intrusive_ptr<c10::BackendMeta> get_backend_meta_intrusive_ptr () const {
1611
+ if (!extra_meta_) {
1612
+ return nullptr ;
1613
+ }
1614
+ return extra_meta_->backend_meta_ ;
1615
+ }
1616
+
1579
1617
protected:
1580
1618
/* *
1581
1619
* Returns the human-readable name of the actual type of this object (e.g.,
0 commit comments