2
2
3
3
import contextlib
4
4
import dataclasses
5
+ import functools
5
6
import typing
6
7
import warnings
7
8
import weakref
@@ -709,7 +710,7 @@ def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
709
710
def meta_storage (
710
711
self ,
711
712
s : MetaStorageDesc ,
712
- callback : Callable [[Callable [[], torch . Tensor ] ], _TensorT ],
713
+ callback : Callable [[Callable ], _TensorT ],
713
714
) -> torch .UntypedStorage :
714
715
# If we are fakeifying a tensor that has a secretly-zero-sized storage,
715
716
# Need to make sure to resize the meta storage too.
@@ -734,7 +735,9 @@ def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
734
735
return typing .cast (_TensorT , t )
735
736
736
737
@classmethod
737
- def _identity_callable (cls , t : Callable [[], torch .Tensor ]) -> _TensorT :
738
+ def _identity_callable (
739
+ cls , t : Callable [[], torch .Tensor ], device : torch .device
740
+ ) -> _TensorT :
738
741
return cls ._checked_cast_tensor_t (t ())
739
742
740
743
@classmethod
@@ -756,10 +759,11 @@ def meta_tensor(
756
759
self ,
757
760
t : MetaTensorDesc ,
758
761
shape_env : Optional [ShapeEnv ],
759
- callback : Callable [[Callable [[], torch . Tensor ] ], _TensorT ],
762
+ callback : Callable [[Callable ], _TensorT ],
760
763
source : Optional [Source ],
761
764
symbolic_context : Optional [SymbolicContext ],
762
765
) -> _TensorT :
766
+ callback = functools .partial (callback , device = t .device ) # type: ignore[call-arg]
763
767
if source is None :
764
768
from torch ._dynamo .source import ConstantSource
765
769
@@ -905,7 +909,7 @@ def _empty_create_subclass(
905
909
symbolic_context : Optional [
906
910
torch .fx .experimental .symbolic_shapes .SymbolicContext
907
911
],
908
- callback : Callable [[Callable [[], torch . Tensor ] ], _TensorT ],
912
+ callback : Callable [[Callable ], _TensorT ],
909
913
source : torch ._guards .Source ,
910
914
) -> _TensorT :
911
915
# We are hitting plain meta_desc tensor so actually
@@ -933,12 +937,15 @@ def _empty_create_subclass(
933
937
)
934
938
935
939
current_source = AttrSource (source , attr )
940
+ inner_callback = functools .partial (
941
+ callback , device = meta_tensor_desc .device # type: ignore[call-arg]
942
+ )
936
943
new_empty_tensor = _empty_create_subclass (
937
944
meta_tensor_desc ,
938
945
meta_tensor_desc .size ,
939
946
meta_tensor_desc .stride ,
940
947
current_context ,
941
- callback ,
948
+ inner_callback ,
942
949
current_source ,
943
950
)
944
951
inner_tensors [attr ] = new_empty_tensor
@@ -975,7 +982,7 @@ def all_dynamic_symbolic_context(
975
982
t : MetaTensorDesc ,
976
983
source : torch ._guards .Source ,
977
984
shape_env : Optional [torch .fx .experimental .symbolic_shapes .ShapeEnv ],
978
- callback : Callable [[Callable [[], torch . Tensor ] ], _TensorT ],
985
+ callback : Callable [[Callable ], _TensorT ],
979
986
) -> torch .fx .experimental .symbolic_shapes .SymbolicContext :
980
987
from torch ._dynamo .source import AttrSource
981
988
from torch .fx .experimental .symbolic_shapes import (
@@ -1137,7 +1144,7 @@ def tensor_visitor_fn(
1137
1144
shape_env : Optional [
1138
1145
torch .fx .experimental .symbolic_shapes .ShapeEnv
1139
1146
] = shape_env ,
1140
- callback : Callable [[Callable [[], torch . Tensor ] ], _TensorT ] = callback , # type: ignore[assignment]
1147
+ callback : Callable [[Callable ], _TensorT ] = callback , # type: ignore[assignment]
1141
1148
) -> torch .Tensor :
1142
1149
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
1143
1150
if visited_t is None :
@@ -1723,17 +1730,17 @@ def __call__(
1723
1730
t : torch .Tensor ,
1724
1731
shape_env : Optional [ShapeEnv ] = None ,
1725
1732
* ,
1726
- callback : Optional [Callable [[Callable [[], torch . Tensor ] ], _TensorT ]] = None ,
1733
+ callback : Optional [Callable [[Callable ], _TensorT ]] = None ,
1727
1734
source : Optional [Source ] = None ,
1728
1735
symbolic_context : Optional [SymbolicContext ] = None ,
1729
1736
# Controls whether or not we should dump the tensor metadata to structured logs
1730
1737
# when source is not None. Because we refakify after Dynamo is done,
1731
1738
# we don't want to dump info again from AOTAutograd, it is redundant.
1732
1739
trace : bool = True ,
1733
1740
) -> _TensorT :
1734
- callback_ : Callable [[Callable [[], torch . Tensor ] ], _TensorT ]
1741
+ callback_ : Callable [[Callable ], _TensorT ]
1735
1742
if callback is None :
1736
- callback_ = self ._identity_callable
1743
+ callback_ = self ._identity_callable # type: ignore[assignment]
1737
1744
else :
1738
1745
callback_ = callback
1739
1746
# TODO: zero tensors? We appear to have eliminated them by
0 commit comments