
Commit e41a0b3

soulitzer authored and pytorchmergebot committed
Allow Fakified subclass to have different device for inner and outer tensor (pytorch#141839)
Previously, if a wrapper tensor subclass was fakified, the inner tensors would end up with the same device as the outer tensor. This PR makes it so that inner and outer tensors can have different devices. See the OffloadTensor PR https://github.com/pytorch/pytorch/pull/141840/files#diff-3bc0cf540b694f4ec0a3749f78b047456657a53a5657e495ffb68e5970c5fdaaR1955 for an application; a simpler test has been added in this PR.

This is technically BC-breaking because the callback passed to MetaConverter now needs to accept an extra argument, but no one external should be using it anyway.

Pull Request resolved: pytorch#141839
Approved by: https://github.com/bdhirsh
ghstack dependencies: pytorch#141166
1 parent 9830e7b commit e41a0b3
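Regarding the BC-breaking note above: any caller that passes its own callback into MetaConverter now has to accept the device as an extra argument, since the converter binds a (possibly different) device for the outer wrapper and each inner tensor. Below is a minimal sketch of the new callback shape, modeled on the updated mk_fake_tensor in this diff; my_callback and its body are illustrative, not the real implementation.

from typing import Callable

import torch

# Old shape (pre-change): the callback only received a thunk that builds
# the meta tensor:
#     def my_callback(make_meta_t: Callable[[], torch.Tensor]) -> torch.Tensor: ...

# New shape: it also receives the device chosen for the tensor being fakified,
# so outer and inner tensors can be given different devices.
def my_callback(
    make_meta_t: Callable[[], torch.Tensor], device: torch.device
) -> torch.Tensor:
    meta_t = make_meta_t()  # build the meta tensor as before
    # Illustrative only: a real callback (like mk_fake_tensor in
    # torch/_subclasses/fake_tensor.py) would wrap meta_t in a FakeTensor
    # that reports `device`.
    return meta_t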

File tree

3 files changed: +82 −15

test/test_fake_tensor.py (+54)

@@ -1951,6 +1951,60 @@ def test_inference_mode(self):
             extract_tensor_metadata(res4),
         )
 
+
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_wrapper_tensor_subclass_different_device(self):
+        class DifferentDeviceTensor(torch.Tensor):
+            @staticmethod
+            def __new__(cls, a):
+                kwargs = {}
+                kwargs["strides"] = a.stride()
+                kwargs["storage_offset"] = a.storage_offset()
+                kwargs["device"] = torch.device("cpu")
+                kwargs["layout"] = a.layout
+                kwargs["requires_grad"] = a.requires_grad
+                kwargs["dtype"] = a.dtype
+                out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs)
+                return out
+
+            def __init__(self, a):
+                self.inner_tensor = a
+
+            def __repr__(self):
+                return f"DifferentDeviceTensor({repr(self.inner_tensor)})"
+
+            def __tensor_flatten__(self):
+                return ["inner_tensor"], None
+
+            @staticmethod
+            def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
+                assert meta is None
+                return DifferentDeviceTensor(inner_tensors["inner_tensor"])
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if kwargs is None:
+                    kwargs = {}
+                args = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, args)
+                kwargs = pytree.tree_map_only(DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs)
+                # Returns unwrapped tensor
+                return func(*args, **kwargs)
+
+        a = torch.ones(2, 2, 768, device="cuda")
+        wrapped_a = DifferentDeviceTensor(a)
+
+        # Outer Tensor is on cpu, inner is on cuda
+        self.assertTrue(wrapped_a.is_cpu)
+        self.assertFalse(wrapped_a.inner_tensor.is_cpu)
+
+        with FakeTensorMode() as fake_mode:
+            fake_wrapped_a = fake_mode.from_tensor(wrapped_a)
+
+        self.assertTrue(fake_wrapped_a.is_cpu)
+        assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
+        self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
+
+
     def test_cache_tuple_outputs(self):
         """
         Test to check that ops with tuple outputs work.

torch/_subclasses/fake_tensor.py (+11 −5)

@@ -354,14 +354,20 @@ def from_real_tensor(
         maybe_memo = self._get_memo(t)
         if maybe_memo is not None:
             return maybe_memo
-        existing_device = t.device
         # not yet supported in metatensors
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         if type(t) is torch.nn.Parameter:
             assert not make_constant
 
-        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
+        constant = t if make_constant else None
+
+        # This callback is used by both subclass and inner tensors. Require the
+        # caller to explicitly specify the device in case outer and inner tensors
+        # have different devices.
+        def mk_fake_tensor(
+            make_meta_t: Callable[[], object], device: torch.device
+        ) -> FakeTensor:
             # NB: don't use in_kernel_invocation_manager. to
             # ensure FakeTensor can internally do constant computation
             # as necessary. Invocation manager is "more correct" as
@@ -373,16 +379,16 @@ def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
             return FakeTensor(
                 fake_mode,
                 make_meta_t(),
-                existing_device,
+                device,
                 # TODO: callback might be used in recursive contexts, in
                 # which case using t is wrong! BUG!
-                constant=t if make_constant else None,
+                constant=constant,
             )
 
         out = self.meta_converter(
             t,
             shape_env=shape_env,
-            callback=mk_fake_tensor,
+            callback=mk_fake_tensor,  # type: ignore[arg-type]
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,

torch/_subclasses/meta_utils.py (+17 −10)

@@ -2,6 +2,7 @@
 
 import contextlib
 import dataclasses
+import functools
 import typing
 import warnings
 import weakref
@@ -709,7 +710,7 @@ def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
     def meta_storage(
         self,
         s: MetaStorageDesc,
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
     ) -> torch.UntypedStorage:
         # If we are fakeifying a tensor that has a secretly-zero-sized storage,
         # Need to make sure to resize the meta storage too.
@@ -734,7 +735,9 @@ def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
         return typing.cast(_TensorT, t)
 
     @classmethod
-    def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
+    def _identity_callable(
+        cls, t: Callable[[], torch.Tensor], device: torch.device
+    ) -> _TensorT:
         return cls._checked_cast_tensor_t(t())
 
     @classmethod
@@ -756,10 +759,11 @@ def meta_tensor(
         self,
         t: MetaTensorDesc,
         shape_env: Optional[ShapeEnv],
-        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+        callback: Callable[[Callable], _TensorT],
         source: Optional[Source],
         symbolic_context: Optional[SymbolicContext],
     ) -> _TensorT:
+        callback = functools.partial(callback, device=t.device)  # type: ignore[call-arg]
         if source is None:
             from torch._dynamo.source import ConstantSource
 
@@ -905,7 +909,7 @@ def _empty_create_subclass(
            symbolic_context: Optional[
                torch.fx.experimental.symbolic_shapes.SymbolicContext
            ],
-            callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+            callback: Callable[[Callable], _TensorT],
            source: torch._guards.Source,
        ) -> _TensorT:
            # We are hitting plain meta_desc tensor so actually
@@ -933,12 +937,15 @@ def _empty_create_subclass(
                )
 
                current_source = AttrSource(source, attr)
+                inner_callback = functools.partial(
+                    callback, device=meta_tensor_desc.device  # type: ignore[call-arg]
+                )
                new_empty_tensor = _empty_create_subclass(
                    meta_tensor_desc,
                    meta_tensor_desc.size,
                    meta_tensor_desc.stride,
                    current_context,
-                    callback,
+                    inner_callback,
                    current_source,
                )
                inner_tensors[attr] = new_empty_tensor
@@ -975,7 +982,7 @@ def all_dynamic_symbolic_context(
            t: MetaTensorDesc,
            source: torch._guards.Source,
            shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
-            callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
+            callback: Callable[[Callable], _TensorT],
        ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import (
@@ -1137,7 +1144,7 @@ def tensor_visitor_fn(
                shape_env: Optional[
                    torch.fx.experimental.symbolic_shapes.ShapeEnv
                ] = shape_env,
-                callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback,  # type: ignore[assignment]
+                callback: Callable[[Callable], _TensorT] = callback,  # type: ignore[assignment]
            ) -> torch.Tensor:
                # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                if visited_t is None:
@@ -1723,17 +1730,17 @@ def __call__(
         t: torch.Tensor,
         shape_env: Optional[ShapeEnv] = None,
         *,
-        callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
+        callback: Optional[Callable[[Callable], _TensorT]] = None,
         source: Optional[Source] = None,
         symbolic_context: Optional[SymbolicContext] = None,
         # Controls whether or not we should dump the tensor metadata to structured logs
         # when source is not None. Because we refakify after Dynamo is done,
         # we don't want to dump info again from AOTAutograd, it is redundant.
         trace: bool = True,
     ) -> _TensorT:
-        callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
+        callback_: Callable[[Callable], _TensorT]
         if callback is None:
-            callback_ = self._identity_callable
+            callback_ = self._identity_callable  # type: ignore[assignment]
         else:
             callback_ = callback
         # TODO: zero tensors? We appear to have eliminated them by
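The functools.partial calls in this file are what carry the per-tensor device into the shared callback: meta_tensor binds the outer tensor's device, and _empty_create_subclass rebinds the callback with each inner tensor's device before recursing. Below is a small standalone sketch of that binding pattern, with a hypothetical make_fake callback standing in for the real one.

import functools
import torch

# Hypothetical callback with the post-change signature: a thunk that builds
# the meta tensor plus the device the resulting fake should report.
def make_fake(make_meta_t, device: torch.device):
    meta_t = make_meta_t()
    return meta_t, device  # a real callback would wrap this in a FakeTensor

# Specialize the shared callback per tensor, as meta_utils.py does with
# functools.partial(callback, device=...):
outer_cb = functools.partial(make_fake, device=torch.device("cpu"))
inner_cb = functools.partial(make_fake, device=torch.device("cuda"))

# Same meta-construction thunk, different reported devices, so the outer
# wrapper and its inner tensor no longer have to agree.
print(outer_cb(lambda: torch.empty(2, 2, device="meta"))[1])  # cpu
print(inner_cb(lambda: torch.empty(2, 2, device="meta"))[1])  # cuda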
