Skip to content

Commit 2d9b2bc

Browse files
ppiskorskipytorchmergebot
authored andcommitted
Extend TensorImpl with BackendMeta (pytorch#97429)
BackendMeta offers a binary interface for the backend to attach arbitrary data to TensorImpl. TensorImpl has exactly one "slot" for backend metadata, however backend is free to compose any structure that is opaque to the framework beyond iheriting standard BackendMeta base. Change-Id: I670fcdd16dd1c2b00f7eaa1cbc5b5dfea59a6221 Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#97429 Approved by: https://github.com/ezyang
1 parent dd50337 commit 2d9b2bc

File tree

2 files changed

+82
-3
lines changed

2 files changed

+82
-3
lines changed

c10/core/TensorImpl.h

+41-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <c10/util/Logging.h>
2020
#include <c10/util/Optional.h>
2121
#include <c10/util/accumulate.h>
22+
#include <c10/util/intrusive_ptr.h>
2223
#include <c10/util/irange.h>
2324
#include <c10/util/python_stub.h>
2425
#include <c10/util/safe_numerics.h>
@@ -217,6 +218,18 @@ is_channels_last_3d
217218
is_non_overlapping_and_dense
218219
#endif
219220

221+
/**
222+
* This structure is intended to hold additional metadata of the specific device
223+
*backend
224+
**/
225+
struct C10_API BackendMeta : intrusive_ptr_target {
226+
virtual ~BackendMeta(){};
227+
virtual intrusive_ptr<BackendMeta> clone(
228+
const intrusive_ptr<BackendMeta>& ptr) const {
229+
return ptr;
230+
}
231+
};
232+
220233
struct C10_API ExtraMeta {
221234
SymDimVector sizes_ = {0};
222235
SymDimVector strides_ = {1};
@@ -229,6 +242,7 @@ struct C10_API ExtraMeta {
229242
SymBool is_channels_last_3d_{false};
230243
SymBool is_non_overlapping_and_dense_{true};
231244
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
245+
intrusive_ptr<c10::BackendMeta> backend_meta_;
232246

233247
ExtraMeta() = default;
234248

@@ -243,7 +257,8 @@ struct C10_API ExtraMeta {
243257
SymBool is_channels_last,
244258
SymBool is_channels_last_3d,
245259
SymBool is_non_overlapping_and_dense,
246-
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta)
260+
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta,
261+
intrusive_ptr<c10::BackendMeta> backend_meta)
247262
: sizes_(std::move(sizes)),
248263
strides_(std::move(strides)),
249264
numel_(std::move(numel)),
@@ -255,7 +270,8 @@ struct C10_API ExtraMeta {
255270
is_channels_last_(std::move(is_channels_last)),
256271
is_channels_last_3d_(std::move(is_channels_last_3d)),
257272
is_non_overlapping_and_dense_(std::move(is_non_overlapping_and_dense)),
258-
named_tensor_meta_(std::move(named_tensor_meta)) {}
273+
named_tensor_meta_(std::move(named_tensor_meta)),
274+
backend_meta_(backend_meta) {}
259275

260276
std::unique_ptr<ExtraMeta> clone() const {
261277
return std::make_unique<ExtraMeta>(
@@ -269,7 +285,8 @@ struct C10_API ExtraMeta {
269285
is_channels_last_,
270286
is_channels_last_3d_,
271287
is_non_overlapping_and_dense_,
272-
named_tensor_meta_ ? named_tensor_meta_->clone() : nullptr);
288+
named_tensor_meta_ ? named_tensor_meta_->clone() : nullptr,
289+
backend_meta_ ? backend_meta_->clone(backend_meta_) : nullptr);
273290
}
274291
};
275292

@@ -1576,6 +1593,27 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
15761593
return data_type_.itemsize();
15771594
}
15781595

1596+
void set_backend_meta(intrusive_ptr<c10::BackendMeta> backend_meta) {
1597+
if (!extra_meta_) {
1598+
extra_meta_ = std::make_unique<ExtraMeta>();
1599+
}
1600+
extra_meta_->backend_meta_ = std::move(backend_meta);
1601+
}
1602+
1603+
c10::BackendMeta* get_backend_meta() {
1604+
if (!extra_meta_) {
1605+
return nullptr;
1606+
}
1607+
return extra_meta_->backend_meta_.get();
1608+
}
1609+
1610+
intrusive_ptr<c10::BackendMeta> get_backend_meta_intrusive_ptr() const {
1611+
if (!extra_meta_) {
1612+
return nullptr;
1613+
}
1614+
return extra_meta_->backend_meta_;
1615+
}
1616+
15791617
protected:
15801618
/**
15811619
* Returns the human-readable name of the actual type of this object (e.g.,

test/cpp/api/tensor.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -1217,3 +1217,44 @@ TEST(TensorTest, ReshapeAlias) {
12171217
torch::_reshape_alias((z * z), {9}, {1}).mean().backward();
12181218
ASSERT_TRUE(torch::equal(y.grad(), z.grad()));
12191219
}
1220+
1221+
TEST(TensorTest, BackendMetadata) {
1222+
// Tests ability to assign custom backend metadata to tensor.
1223+
1224+
struct CustomBackendMetadata : public c10::BackendMeta {
1225+
mutable bool cloned_{false}; // for testing this field will mutate when
1226+
// clone() is called by shallow_copy_from.
1227+
c10::intrusive_ptr<c10::BackendMeta> clone(
1228+
const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
1229+
cloned_ = true;
1230+
return c10::BackendMeta::clone(ptr);
1231+
}
1232+
};
1233+
1234+
at::Tensor y;
1235+
c10::intrusive_ptr<c10::BackendMeta> tmeta{};
1236+
CustomBackendMetadata* custom_tmeta{nullptr};
1237+
1238+
{
1239+
auto x = torch::ones({3, 3});
1240+
auto impl{x.unsafeGetTensorImpl()};
1241+
ASSERT_TRUE(impl != nullptr);
1242+
1243+
tmeta = impl->get_backend_meta_intrusive_ptr();
1244+
ASSERT_TRUE(tmeta == nullptr);
1245+
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{
1246+
std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata())};
1247+
impl->set_backend_meta(new_tmeta);
1248+
tmeta = impl->get_backend_meta_intrusive_ptr();
1249+
ASSERT_TRUE(tmeta == new_tmeta);
1250+
custom_tmeta = dynamic_cast<CustomBackendMetadata*>(tmeta.get());
1251+
ASSERT_TRUE(custom_tmeta != nullptr);
1252+
ASSERT_TRUE(custom_tmeta->cloned_ == false);
1253+
y.unsafeGetTensorImpl()->shallow_copy_from(x.getIntrusivePtr());
1254+
}
1255+
1256+
ASSERT_TRUE(
1257+
tmeta == y.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr());
1258+
ASSERT_TRUE(tmeta.get() == y.unsafeGetTensorImpl()->get_backend_meta());
1259+
ASSERT_TRUE(custom_tmeta->cloned_ == true);
1260+
}

0 commit comments

Comments
 (0)