
Commit f25c7c9

bdhirsh authored and pytorchmergebot committed
functionalize storage resizing, minimal ppFSDP traceable forward (pytorch#122434)
More details further down, but first a more high-level description of "how do we functionalize storage resizing".

Today, dynamo converts `param.untyped_storage().resize_(x)` calls that it sees from fsdp into a custom op, `ops.inductor.resize_storage_bytes_(x)`. Given this setup, there are 3 main cases that we want to handle:

(1) The graph input starts with a real storage size and gets resized down to zero in the graph.
(2) The graph input starts with 0 storage size and gets resized up in the graph.
(3) The graph input starts with 0 storage size, gets resized up, is used in some compute, and is then resized back down to 0.

For case (1) we need to emit a `resize_storage_bytes_` at the end of the graph, similar to how we emit `copy_()` for data mutations. For case (2), we need to emit a `resize_storage_bytes_` in the graph, and we **also** need to emit a `copy_()` (the input had its storage resized up and filled in with data, which we need to reflect as an input mutation). For case (3), the net effect is that the input had no data on entry and exit of the function, so we don't need to emit any mutable ops at the end of the graph.

The main thing to call out is that we need to write a functionalization rule for `resize_storage_bytes_` (`FunctionalTensorWrapper::storage_resize_()`), and this rule actually does very little. We would like to **not** emit any new ops in the graph (like, say, a functional resize op). Instead, we expect / rely on the fact that any resize up will be immediately followed by a `copy_()`/`foreach_copy_`/`out=` op that fills in the data of the tensor. So a `FunctionalTensor` can temporarily live in a state where its data is invalid, until the `x.copy_(y)` "updates" its data with the new tensor. Effectively, all that this rule does is:

(1) It stores metadata on the storage indicating that the tensor was resized, as well as the updated storage size. We need this info in AOTAutograd so it knows whether to emit a mutable `resize_()` op in the graph epilogue.
(2) There is also a corner case: if we are resizing down to zero, but our tensor had **previously** had a zero-size storage, then we update `value_` to point to the original value of the tensor. The reason this seems safe is that if we have a zero-storage-sized tensor `x`, resize it up, use it in some compute, resize it back down to zero, and use it somewhere, we would want the functional version of this code to use the original `x` after the second resize. For FSDP this is important because we end up saving parameters (graph inputs) for backward, and we want to make sure that the thing we save (and the output of the forward graph) is the original, zero-storage-sized parameter, and not "version 2" of the parameter after the first `resize_()`.
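To make the three cases concrete, here is a minimal sketch (not taken from this PR's test suite) of the eager patterns they correspond to; the byte counts and surrounding compute are made up for illustration, and the comments describe what AOTAutograd is expected to emit in each case:

```python
import torch

def case1_resize_down(param: torch.Tensor) -> torch.Tensor:
    # (1) starts with real storage, resized down to zero inside the graph:
    # a mutable resize_storage_bytes_(param, 0) must be emitted in the epilogue.
    out = param * 2
    param.untyped_storage().resize_(0)
    return out

def case2_resize_up(param: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # (2) starts with zero storage, resized up and filled in:
    # both a resize_storage_bytes_ and a copy_() (input mutation) must be emitted.
    param.untyped_storage().resize_(src.numel() * src.element_size())
    with torch.no_grad():
        param.copy_(src)
    return param.sum()

def case3_roundtrip(param: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # (3) zero storage on entry and exit: the resizes cancel out, so no mutable
    # ops need to be emitted at the end of the graph.
    param.untyped_storage().resize_(src.numel() * src.element_size())
    with torch.no_grad():
        param.copy_(src)
    out = param @ param
    param.untyped_storage().resize_(0)
    return out
```

Under torch.compile, dynamo traces the `untyped_storage().resize_()` calls above into the `ops.inductor.resize_storage_bytes_` custom op described earlier.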
I think a good order to look at the changes in this PR would be:

(1) `test_aotdispatch.py` shows the 3 main cases I focused on, as well as the expected functionalized graphs.
(2) In `FunctionalStorageImpl.h/cpp`, I had to add a notion of "original base" and "original/curr_size". The first is so I can re-use the zero-size tensor after multiple resizes, and the second is so I can tell in AOTAutograd whether any resizes canceled each other out into a no-op.
(3) `FunctionalTensorWrapper.h/cpp` has the new resize functionalization rule plus some extra utils.
(4) `_functorch/_autograd`: the main changes in this folder were around adding the logic at trace time to detect when we need to put a `resize_()` in the graph. I also have some assertions to check that any inputs that experience storage resizing will **always be in the graph** and not the opaque epilogue, and I also limited the `resize_()` mutation case so that you can only ever start with zero storage or end with zero storage (you can't do e.g. `torch.ones(2).storage().resize_(3)`), and banned it on tensor subclasses.
(5) `fake_tensor.py`/`meta_utils.py`: we now need to be able to fakeify tensors with zero storage, so I added a quick version of it in meta_utils.py. This also has ramifications for fake tensor caching that I need to fix (include the storage size on the cache key, maybe?).

------------------

This PR subsumes pytorch#120971. It is enough to **almost** get a simple ppFSDP forward pass tracing with a functionalized `resize_()` properly. It also attempts the updated version from @jansel, where we don't have any notion of `resize_()` in the graph at all, post functionalization. It would probably be good to test it with @yf225's FSDP changes and see how many of the FX passes it allows us to remove. I think that, in theory, it should allow us to remove all FX passes that affect the forward graph / partitioner, **except** the one that forces views to be recomputed in the backward (more details below).

There are a few things worth calling out:

(1) Failed attempt at functionalizing `aten.copy_()`. I originally wanted to get a version that takes these operations:

```
param.storage().resize_(all_gather_size)
param.copy_(all_gather_buffer)
out = aten.matmul(param, param)
```

and functionalizes them into:

```
out = aten.matmul(all_gather_buffer, all_gather_buffer)
```

This would involve getting functionalization to turn `x.copy_(y)` into a giant no-op that just returns `y`. Unfortunately, we can't actually do this in a reasonable way within functionalization (instead, there's a functional `aten.copy` in the graph - see the test case graph expecttest for details). Why? In order for that transformation to be safe, `x` and `y` need to have the same metadata. However, it's possible for `x` and `y` to be subclasses of different types. This is not something we can easily tell from within functionalization, and it would be a layering violation. So for now I'm leaving it to downstream code to optimize away the `aten.copy` (this is already the case today, so I think inductor can handle this).

(2) The forward doesn't **actually** run successfully in this PR (see the `assertRaisesRegex` in the test). Why? The final forward graph looks like this:

```
def forward(self, primals_1, primals_2):
    _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]);  primals_2 = None
    getitem = _foreach_copy[0];  _foreach_copy = None
    mm = torch.ops.aten.mm.default(getitem, getitem);  getitem = None
    t_1 = torch.ops.aten.t.default(primals_1);  primals_1 = None
    return [mm, t_1]
```

Here `primals_1` starts out as a secretly-zero-storage-size parameter, and gets resized up and back down within the forward (these resizes are functionalized away). Importantly, the matmul happens on the result of the `_foreach_copy`, **but** the activation that we save for backward (`t_1`) is the result of transposing the **original parameter** (the zero-storage-size param). This is exactly the optimization in fsdp that allows us to have good peak memory usage. The problem is that the min-cut partitioner decides to save `t_1` for backward. Running this code in eager breaks, because the kernel for `aten.permute(x)` is not happy when `x` has secretly-zero-sized storage.
The real problem here is that in eager mode the `permute` kernel runs during the backward, after the backward hooks have properly resized the saved activation; here, we are running the transpose in the forward. One option would be to turn off the checks in our view kernels and allow them to work on zero-storage-sized tensors, which feels pretty bad. Another option is to tweak the partitioner (or use one of Will's FX passes) to force the partitioner to not save views for backward, and allow the views to be recomputed in the backward. This seems kind of silly, but is also probably harmless.

(3) The backward is still broken. To be fair, this issue is pretty separable from "functionalizing storage resize calls" and can be fixed later (either by a real fix to our tracing infra, or via another hacky FX pass). More detail on this problem is given at issue (8) of my PR description in pytorch#120971.

(4) I only added support for "full graph" resizing: basically, the limited case where a param starts with zero storage size and gets resized up and back down. I think we can add support for the graph-break case, but we can keep that add-on separate from this PR unless we need it immediately. I also added asserts so we should fail loudly when we hit this case.

(5) I have a change to FakeTensor creation when inputs have zero storage size that is probably OK. But I also removed FakeTensor caching on view ops, which I probably need to fix before I can land this PR.

(6) I added a notion of "original_base" to `FunctionalStorageImpl`. More details are in the comments, but my rationale for this was that we basically need it to ensure that autograd saves the **original**, zero-storage-sized param for backward, after resizing up and back down.

(7) I had to update our eager kernels for `aten.copy` and `aten._foreach_copy` to handle the case where the `self` argument has secretly-zero storage. Inductor can probably generate correct code for this case, but we need these ops to work properly in this situation for the `aot_eager` backend to do the right thing.
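For intuition, here is a rough Python rendering (the helper name is made up; the real change is to the C++ `copy`/`_foreach_copy` kernels in `aten/src/ATen/native/Copy.cpp` below) of what the updated eager `copy` kernel does when `self` has secretly-zero storage:

```python
import torch

def functional_copy_sketch(self_t: torch.Tensor, src: torch.Tensor, non_blocking: bool = False) -> torch.Tensor:
    # Illustrative sketch of the eager copy() kernel behavior added in this PR.
    if self_t.untyped_storage().nbytes() == 0:
        # self has "secretly zero" storage: we can't clone it, so materialize an
        # empty tensor with the same sizes/strides and let copy_() fill it in.
        r = torch.empty_strided(
            self_t.size(), self_t.stride(), dtype=self_t.dtype, device=self_t.device
        )
    else:
        r = self_t.clone()  # the C++ kernel uses clone_preserve_strides()
    r.copy_(src, non_blocking=non_blocking)
    return r
```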
Pull Request resolved: pytorch#122434
Approved by: https://github.com/jansel

1 parent f42ea14 commit f25c7c9

File tree

43 files changed: +704 -64 lines changed


aten/src/ATen/FunctionalStorageImpl.cpp

+10 -1

@@ -97,7 +97,16 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
       /*resizable=*/true
     ),
     base_(base)
-  {
+  {
+  // SparseTensorImpl has no storage, so we cannot query its nbytes.
+  // (original_storage_size is only used for storage resizing in fsdp anyway, which does not apply to sparse)
+  // Same for XLA
+  if (base.unsafeGetTensorImpl()->has_storage() && base.device().type() != c10::DeviceType::XLA) {
+    original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
+  } else {
+    original_storage_size_ = -1;
+  }
+  curr_storage_size_ = original_storage_size_;
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
 }

aten/src/ATen/FunctionalStorageImpl.h

+24

@@ -105,6 +105,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
     frozen_ = true;
   }

+  c10::SymInt get_storage_size(bool before) {
+    if (before) {
+      return original_storage_size_;
+    } else {
+      return curr_storage_size_;
+    }
+  }
+
   ~FunctionalStorageImpl() override = default;

   void mark_mutation() {
@@ -132,6 +140,15 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
     return mutation_counter_ <= mutation_counter_hidden_from_autograd_;
   }

+  void mark_inductor_storage_resize(c10::SymInt new_size) {
+    inductor_storage_resized_ = true;
+    curr_storage_size_ = new_size;
+  }
+
+  bool was_inductor_storage_resized() {
+    return inductor_storage_resized_;
+  }
+
 private:
  // NB: base_ should always point to a tensor BELOW the current
  // functionalization layer. This is mainly to avoid reference cycles. e.g.
@@ -172,6 +189,13 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
  uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0;
  uint64_t mutation_counter_ = 0;
  uint64_t mutation_counter_hidden_from_autograd_ = 0;
+
+ // Used to tell if:
+ // (1) There were any storage resizes on a graph input
+ // (2) The original/curr storage size tell us if these resizes result in a nop
+ bool inductor_storage_resized_ = false;
+ c10::SymInt original_storage_size_;
+ c10::SymInt curr_storage_size_;
 };

} // namespace at::functionalization
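To make the intent of these accessors concrete, here is a hypothetical sketch of the decision a consumer like AOTAutograd needs to make with this information when building the graph epilogue; the class and function names below are illustrative only, not the actual AOTAutograd code:

```python
from dataclasses import dataclass

@dataclass
class InputStorageInfo:
    # Hypothetical mirror of the C++ state tracked above, for illustration only.
    was_resized: bool      # mirrors was_inductor_storage_resized()
    original_nbytes: int   # mirrors get_storage_size(/*before=*/true)
    curr_nbytes: int       # mirrors get_storage_size(/*before=*/false)

def needs_epilogue_resize(info: InputStorageInfo) -> bool:
    # No resize calls at all, or resizes that cancel out (same size on entry
    # and exit of the graph, e.g. zero -> N -> zero), require nothing extra.
    if not info.was_resized:
        return False
    return info.original_nbytes != info.curr_nbytes
```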

aten/src/ATen/FunctionalTensorWrapper.cpp

+26

@@ -276,6 +276,32 @@ void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
   set_sizes_and_strides(sizes_, strides_, storage_offset_);
 }

+void FunctionalTensorWrapper::storage_resize_(c10::SymInt new_size) {
+  auto curr_storage_size = value_.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
+  // storage resizing is severely limited: we only support resizing either to zero, or from zero bytes.
+  TORCH_CHECK(new_size == 0 || curr_storage_size == 0, "new_size: ", new_size, ". curr_storage_size: ", curr_storage_size);
+  // The "functionalization rule" for storage resizing is a giant no-op, mainly because we don't want
+  // resize_() calls to actually emit any ops in the functional graph.
+  // How does it work?
+  // Resizing up (old size == 0):
+  //   We do nothing in this case.
+  //   The expectation is that, for the user code to be valid, the next op that runs against the current tensor "x"
+  //   will be a x.copy_(y) (or similar) that fully overwrites the data of x.
+  //   If there are any outstanding aliases of x, we expect them not to be used until after the copy_() call
+  //   (otherwise the eager code would be invalid),
+  //   and therefore functionalization will regenerate the aliases off of the result of `x.copy(y)`.
+  // Resizing down (new size == 0):
+  //   We also do nothing in this case. The assumption is that after resizing a tensor down,
+  //   it is fully unused in the program (unless it is first resized back up and has data copied in),
+  //   although it might be saved for backward, which happens in FSDP.
+  //   The expected pattern is that the param will then be resized back up from zero in the backward.

+  // Mark the tensor as having its storage resized.
+  // This is so we can detect it for inputs in AOTAutograd and error / emit
+  // an input mutation resize_() appropriately
+  functional_storage_impl()->mark_inductor_storage_resize(new_size);
+}
+
 void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
   // Note [resize_() in functionalization pass]
   // resize_() is a special operator in functionalization because it can reallocate its underlying storage.

aten/src/ATen/FunctionalTensorWrapper.h

+13

@@ -141,6 +141,9 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   // Custom implementation of self.set_(src)
   void set__impl(const FunctionalTensorWrapper* other);

+  // Custom implementation of resize_storage_bytes_(self, new_size)
+  void storage_resize_(c10::SymInt new_size);
+
   // Returns whether the current tensor's data was ever mutated
   bool has_data_mutation();
   //
@@ -150,6 +153,16 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
     return was_storage_changed_;
   }

+  c10::SymInt get_storage_size(bool before) {
+    return functional_storage_impl()->get_storage_size(before);
+  }
+
+  // Returns whether the FunctionalTensor experienced an
+  // untyped_storage().resize_() call
+  bool was_inductor_storage_resized() {
+    return functional_storage_impl()->was_inductor_storage_resized();
+  }
+
   // The functionalization pass can be used to remove mutations.
   // It does so by replacing any mutation op with it's corresponding
   // out-of-place op, followed by a call to replace_(). e.g:

aten/src/ATen/FunctionalizeFallbackKernel.cpp

+3

@@ -335,6 +335,9 @@ static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
   auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
   auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src);
+  // See Note [Ordering of resize_() and set_()]
+  TORCH_CHECK(!self_impl->was_inductor_storage_resized(),
+    "storage_resize_() followed by set_() in torch.compile is not supported today");
   self_impl->set__impl(src_impl);
   return self;
 }
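In user-code terms, the pattern this new check rejects under torch.compile is roughly the following (illustrative sketch):

```python
import torch

def rejected_pattern(param: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # Resize the (functionalized) storage, then swap the tensor's storage with set_().
    # Under torch.compile, functionalization now hits the TORCH_CHECK added above and
    # raises, since this ordering is not supported today.
    param.untyped_storage().resize_(other.numel() * other.element_size())
    param.set_(other)
    return param
```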

aten/src/ATen/native/Copy.cpp

+36 -1

@@ -1,5 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/native/Copy.h>
+#include <ATen/native/Copy.h>

 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
@@ -25,8 +26,12 @@
 #include <ATen/ops/_copy_from.h>
 #include <ATen/ops/_propagate_xla_data.h>
 #include <ATen/ops/_propagate_xla_data_native.h>
+#include <ATen/ops/copy.h>
 #include <ATen/ops/copy_native.h>
+#include <ATen/ops/_foreach_copy.h>
+#include <ATen/ops/_foreach_copy_native.h>
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_strided.h>
 #include <ATen/ops/expand_copy.h>
 #endif

@@ -303,15 +308,45 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
   return self;
 }

+Tensor copy_meta(const Tensor& self, const Tensor& src, bool non_blocking) {
+  // Must directly use self(), so we can dispatch properly if self is a subclass
+  auto r = clone_preserve_strides(self);
+  r.copy_(src, non_blocking);
+  return r;
+}
+
 Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
+  at::Tensor r;
   // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
   // (1) It isn't exposed to the frontend (no python bindings)
   // (2) It isn't exposed to the backend (it's a composite, that decomposes into to() and expand_as() calls.
-  auto r = clone_preserve_strides(self);
+  auto self_storage = self.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl();
+  // If self has no real storage, we can't actually clone it.
+  // Instead, generate an empty tensor with the right sizes/strides, since we should be able to assume
+  // that copy_() will fully overwrite all data with that of src
+  if (self_storage->nbytes() == 0) {
+    r = at::empty_strided(self.sizes(), self.strides());
+  } else {
+    r = clone_preserve_strides(self);
+  }
   r.copy_(src, non_blocking);
   return r;
 }

+::std::vector<at::Tensor> _foreach_copy(at::TensorList self, at::TensorList src, bool non_blocking) {
+  std::vector<at::Tensor> outs;
+  outs.reserve(self.size());
+  // This is a very slow implementation, but it needs to directly call the copy() kernel above to handle
+  // the case where self has zero storage.
+  // This kernel should never really be run, except when debugging with compile(backend="aot_eager")
+  for (const auto i : c10::irange(src.size())) {
+    auto curr_src = src[i];
+    auto curr_self = self[i];
+    outs.push_back(at::copy(curr_self, curr_src, non_blocking));
+  }
+  return outs;
+}
+
 Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
   auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
   {

aten/src/ATen/native/native_functions.yaml

+8 -1

@@ -1750,6 +1750,7 @@
 - func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
   variants: function
   dispatch:
+    Meta: copy_meta
     CompositeExplicitAutogradNonFunctional: copy
   tags: core

@@ -11357,7 +11358,13 @@
   dispatch:
     CPU: foreach_tensor_copy_list_kernel_slow_
     CUDA: foreach_tensor_copy_list_kernel_cuda_
-  autogen: _foreach_copy, _foreach_copy.out
+  autogen: _foreach_copy.out
+
+- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _foreach_copy

 - func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
   dispatch:
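The new functional `_foreach_copy` overload is what appears in the traced forward graph; called directly, it returns fresh outputs instead of mutating `self`. A small usage sketch (assuming a build that includes this change):

```python
import torch

a = torch.zeros(4, 4)
b = torch.randn(4, 4)

# The functional variant returns new tensors rather than writing into `a`.
(out,) = torch.ops.aten._foreach_copy.default([a], [b])
assert torch.equal(out, b)
assert torch.equal(a, torch.zeros(4, 4))  # `a` is left untouched
```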

test/dynamo/test_repros.py

+89

@@ -160,6 +160,32 @@ def shapes_to_tensor(x, device=None):
     return torch.as_tensor(x, device=device)


+fw_graph = [None]
+bw_graph = [None]
+
+
+def aot_graph_capture_backend(gm, args):
+    from functorch.compile import min_cut_rematerialization_partition
+    from torch._functorch.aot_autograd import aot_module_simplified
+
+    def fw_compiler(gm, _):
+        fw_graph[0] = gm
+        return gm
+
+    def bw_compiler(gm, _):
+        bw_graph[0] = gm
+        return gm
+
+    return aot_module_simplified(
+        gm,
+        args,
+        fw_compiler,
+        bw_compiler,
+        partition_fn=min_cut_rematerialization_partition,
+        keep_inference_input_mutations=True,
+    )
+
+
 class Boxes:
     # from detectron2 poolers.py
     def __init__(self, tensor: torch.Tensor):
@@ -4644,6 +4670,69 @@ def fn(instances):
         self.assertEqual(type(actual), type(expected))
         self.assertEqual(actual.__dict__, expected.__dict__)

+    def test_storage_resize_forward_full_graph(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.param = torch.nn.Parameter(torch.randn(4, 4))
+
+            def forward(self, x):
+                self.param.untyped_storage().resize_(
+                    self.param.numel() * self.param.itemsize
+                )
+                with torch.no_grad():
+                    torch._foreach_copy_([self.param], [x])
+                out = torch.matmul(self.param, self.param)
+                self.param.untyped_storage().resize_(0)
+                return out
+
+        def post_accumulate_grad_hook(param):
+            param.untyped_storage().resize_(0)
+
+        # Beginning of backward, resize and put data into the param
+        def pre_backward_hook(module, grad) -> None:
+            module.param.untyped_storage().resize_(
+                self.param.numel() * self.param.itemsize
+            )
+            with torch.no_grad():
+                # simulates loading data into param from allgather
+                module.param.fill_(2)
+
+        def post_forward_hook(module, args, output):
+            output.register_hook(functools.partial(pre_backward_hook, module))
+
+        x = torch.randn(4, 4)
+
+        mod_ref = TestModule()
+        mod_test = deepcopy(mod_ref)
+
+        # Start the param off with zero storage size to mimic fsdp
+        mod_ref.param.untyped_storage().resize_(0)
+        mod_test.param.untyped_storage().resize_(0)
+
+        # Resize storage at beginning of backward
+        # Free storage at end of backward
+        mod_ref.register_forward_hook(post_forward_hook, prepend=False)
+        mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
+        mod_test.register_forward_hook(post_forward_hook, prepend=False)
+        mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
+
+        mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)
+
+        out_ref = mod_ref(x)
+        out_test = mod_test(x)
+        self.assertExpectedInline(
+            str(fw_graph[0].code.strip()),
+            """\
+def forward(self, primals_1, primals_2):
+    _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]);  primals_1 = primals_2 = None
+    getitem = _foreach_copy[0];  _foreach_copy = None
+    mm = torch.ops.aten.mm.default(getitem, getitem)
+    t_1 = torch.ops.aten.t.default(getitem);  getitem = None
+    return [mm, t_1]""",
+        )
+        self.assertEqual(out_ref, out_test)
+
     def test_super_in_staticmethod(self):
         class A:
             @staticmethod

The following 17 files contain whitespace-only changes:

test/dynamo_expected_failures/FakeTensorTest.test_embedding_bag_meta
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSC
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCSR
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cuda_bias_weightCOO
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cuda_bias_weightCSC
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cuda_bias_weightCSR
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cuda_nobias_weightCOO
test/dynamo_expected_failures/TestNN.test_swap_module_params_fails_after_forward
test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_False_cpu
test/dynamo_expected_failures/TestNNParametrizationDeviceCPU.test_weight_norm_parametrization_swap_True_cpu
test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_False_cuda
test/dynamo_expected_failures/TestNNParametrizationDeviceCUDA.test_weight_norm_parametrization_swap_True_cuda
test/dynamo_expected_failures/TestNestedTensorDeviceTypeCPU.test_embedding_jagged_cpu
test/dynamo_skips/TestConvolutionNN.test_ConvTranspose2d_output_size_downsample_upsample
test/dynamo_skips/TestConvolutionNNDeviceTypeCPU.test_conv2d_no_grad_cpu_float32
test/dynamo_skips/TestNNParametrization.test_new_spectral_norm_dim_swap_False
test/dynamo_skips/TestVmapOperators.test_conv2d
