
Commit 0b7d971

jansel authored and pytorchmergebot committed
[dynamo] Add support for nn.Parameter constructor (part 2) (pytorch#120965)
This handles the case where the tensor isn't an input. The changes to dynamo tests are cases where we would previously fall back to eager.

Pull Request resolved: pytorch#120965
Approved by: https://github.com/yanboliang
ghstack dependencies: pytorch#121735
1 parent 040b925 commit 0b7d971
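For context, a minimal sketch of the pattern this change enables (mirroring the new test_nn_param_return3 test below): an nn.Parameter constructed from an intermediate value inside a compiled function now compiles under fullgraph=True instead of falling back to eager.

import torch

def fn(x):
    # The parameter is built from an intermediate (x + 123), not a graph input.
    p = torch.nn.Parameter(x + 123)
    return p, p.sin()

opt = torch.compile(fn, fullgraph=True)  # fullgraph=True forbids eager fallback
p, r = opt(torch.randn(16))
r.sum().backward()  # gradients flow into the newly created parameter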

File tree

99 files changed: +184 -2 lines changed


The following marker files all contain whitespace-only changes:

test/dynamo_expected_failures/TestEmbeddingNN.test_embeddingbag_include_last_offset
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Bilinear
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_discontiguous
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_max_padding_idx
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_mean_padding_idx
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sparse
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_EmbeddingBag_sum_padding_idx
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_discontiguous
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Embedding_sparse
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_Linear_no_batch_dim
test/dynamo_expected_failures/TestJitGeneratedModule.test_nn_PReLU_no_batch_dim
test/dynamo_expected_failures/TestNN.test_ParameterList
test/dynamo_expected_failures/TestNN.test_bilinear_broadcasting
test/dynamo_expected_failures/TestNN.test_layer_norm_grads_with_create_graph_flag
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_bias_weightCOO
test/dynamo_expected_failures/TestNN.test_linear_autograd_device_cpu_nobias_weightCOO
test/dynamo_expected_failures/TestNN.test_linear_broadcasting
test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op
test/dynamo_expected_failures/TestNNParametrization.test_errors_unparametrized_tensor_parametrization
test/dynamo_expected_failures/TestPruningNN.test_identity_pruning
test/dynamo_expected_failures/TestPruningNN.test_random_pruning_0perc
test/dynamo_skips/TestConvolutionNN.test_Conv1d_module_same_padding
test/dynamo_skips/TestConvolutionNN.test_Conv2d_backward_twice
test/dynamo_skips/TestConvolutionNN.test_Conv2d_module_same_padding
test/dynamo_skips/TestConvolutionNN.test_Conv3d_module_same_padding
test/dynamo_skips/TestConvolutionNN.test_ConvTranspose3d_correct_output_size
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_circular_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_groups
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad1size1
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad2size1
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_same_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_pad_valid
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_reflect_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_replicate_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_stride
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zero_batch
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv1d_zeros_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_circular_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_padded
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_strided
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_depthwise_with_multiplier
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_groups_thnn
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_same_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_pad_valid
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_padding
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_reflect_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_replicate_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_strided
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zero_batch
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv2d_zeros_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_circular_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_dilated_strided
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_groups
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_same_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_pad_valid
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_replicate_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_stride_padding
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zero_batch
test/dynamo_skips/TestJitGeneratedModule.test_nn_Conv3d_zeros_stride2_pad2
test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d
test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_dilated
test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose1d_groups
test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d
test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose2d_groups
test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d
test/dynamo_skips/TestJitGeneratedModule.test_nn_ConvTranspose3d_dilated
test/dynamo_skips/TestNN.test_padding_list
test/dynamo_skips/TestNN.test_vector_to_parameters

test/inductor/test_distributed_patterns.py (+30)

@@ -301,6 +301,36 @@ def fn(x):
         self._assert_same_grad(r1, r2)
         self._assert_same_grad(p1, p2)
 
+    def test_nn_param_return3(self):
+        def fn(x):
+            p = torch.nn.Parameter(x + 123)
+            return p, p.sin()
+
+        opt = torch.compile(fn, fullgraph=True)
+        x1 = torch.randn(16)
+        x2 = x1.clone()
+
+        p1, r1 = fn(x1)
+        r1.sum().backward()
+        p2, r2 = opt(x2)
+        r2.sum().backward()
+        self._assert_same_grad(r1, r2)
+        self._assert_same_grad(p1, p2)
+
+    def test_nn_param_return4(self):
+        def fn(x):
+            p = torch.nn.Parameter(x + 123, requires_grad=False)
+            return p, x + 1
+
+        opt = torch.compile(fn, fullgraph=True)
+        x1 = torch.randn(16)
+        x2 = x1.clone()
+
+        p1, r1 = fn(x1)
+        p2, r2 = opt(x2)
+        self._assert_same_grad(r1, r2)
+        self._assert_same_grad(p1, p2)
+
 
 if __name__ == "__main__":
     if HAS_CPU and not IS_MACOS:

torch/_dynamo/create_parameter_op.py (+50, new file)

@@ -0,0 +1,50 @@
+import torch
+from torch._prims import _make_prim, RETURN_TYPE
+from torch._prims_common import clone_preserve_strides
+
+doc = """
+This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
+with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
+becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
+actually should be created we mutate this sacrificial placeholder into it. This allows gradients
+to flow into the parameter as if it were an input to the graph (which is the only thing we are
+allowed to compute gradients on).
+""".strip()
+
+_bind_nn_parameter = _make_prim(
+    schema="_bind_nn_parameter(Tensor self, Tensor placeholder) -> Tensor",
+    return_type=RETURN_TYPE.NEW,
+    meta=lambda self, placeholder: torch.nn.Parameter(
+        clone_preserve_strides(self), placeholder.requires_grad
+    ),
+    impl_aten=lambda self, placeholder: placeholder.set_(self),
+    doc=doc,
+)
+torch.fx.node.has_side_effect(_bind_nn_parameter)
+
+
+class TracableCreateParameter(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, placeholder):
+        assert not tensor.requires_grad
+        return _bind_nn_parameter(tensor, placeholder)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return None, grad  # grad flows to placeholder
+
+
+def tracable_create_parameter(tensor, placeholder):
+    with torch.set_grad_enabled(placeholder.requires_grad):
+        return TracableCreateParameter.apply(tensor, placeholder)
+
+
+def new_parameter_placeholder(size, dtype, device, requires_grad):
+    """Create a placeholder to be passed to the above functions"""
+    result = torch.nn.Parameter(
+        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
+    )
+    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
+    # Allocating a zero tensor would cause assert failures in autograd.
+    result.untyped_storage().resize_(0)
+    return result

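A quick illustration of the placeholder trick described in the docstring above (a sketch exercising the new module, not part of the diff): the placeholder keeps full tensor metadata but owns no storage, so it is cheap to hold as a graph input.

import torch
from torch._dynamo.create_parameter_op import new_parameter_placeholder

p = new_parameter_placeholder((4, 4), torch.float32, torch.device("cpu"), True)
print(p.shape, p.dtype, p.requires_grad)  # torch.Size([4, 4]) torch.float32 True
print(p.untyped_storage().size())         # 0 -- the backing storage was freed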
torch/_dynamo/output_graph.py (+23)

@@ -69,6 +69,7 @@
     LocalSource,
     ParamBufferSource,
     ShapeEnvSource,
+    SyntheticLocalSource,
     TensorProperty,
     TensorPropertySource,
 )
@@ -472,6 +473,28 @@ def init_ambient_guards(self):
 
         self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))
 
+    def synthetic_graph_input(self, fn, args):
+        """
+        call fn(*args) before the graph runs and turn the result into a fake input.
+        """
+        example_value = fn(*args)
+        varname = self.new_var()
+        cg = PyCodegen(self.root_tx)
+        cg.load_import_from(
+            fn.__module__,
+            fn.__name__,
+        )
+        cg.foreach(map(variables.ConstantVariable.create, args))
+        cg.call_function(len(args), True)
+        cg.store(varname)
+        self.pregraph_bytecode.extend(cg.get_instructions())
+        source = SyntheticLocalSource(varname)
+        result = VariableBuilder(self.root_tx, source)(example_value)
+        TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
+            source
+        )
+        return result
+
     def add_cleanup_hook(self, fn: Callable[[], Any]):
         self.cleanup_hooks.append(fn)

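In effect, synthetic_graph_input prepends bytecode equivalent to the hand-written sketch below (the variable name is illustrative), stores the result in a synthetic local, and wraps that local as an extra graph input; the guards on the SyntheticLocalSource are then dropped, since its value is fully determined by the constant args.

import torch
from torch._dynamo.create_parameter_op import new_parameter_placeholder

# Roughly what the emitted pregraph bytecode does for the Parameter() case:
tmp_0 = new_parameter_placeholder((16,), torch.float32, torch.device("cpu"), True)
# tmp_0 then becomes a SyntheticLocalSource-backed input to the compiled graph.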
torch/_dynamo/variables/torch.py (+30 -1)

@@ -17,6 +17,7 @@
 from ..._guards import TracingContext
 from .. import config, polyfill, variables
 from ..codegen import PyCodegen
+from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter
 from ..device_interface import get_registered_device_interfaces
 from ..exc import unimplemented
 from ..guards import GuardBuilder, install_guard
@@ -840,7 +841,35 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True):
         if data.source:
             return cls._nn_param_via_prefix_insert(tx, data, requires_grad)
 
-        unimplemented("Parameter() on non-input")
+        try:
+            shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
+            dtype = data.var_getattr(tx, "dtype").as_python_constant()
+            device = data.var_getattr(tx, "device").as_python_constant()
+        except NotImplementedError as e:
+            unimplemented(f"Parameter not python_constant: {e}")
+
+        placeholder = tx.output.synthetic_graph_input(
+            new_parameter_placeholder, [shape, dtype, device, requires_grad]
+        )
+        if data.requires_grad:
+            data = data.call_method(tx, "detach", [], {})
+
+        from .builder import wrap_fx_proxy
+
+        result = wrap_fx_proxy(
+            tx,
+            tx.output.create_proxy(
+                "call_function",
+                tracable_create_parameter,
+                (data.as_proxy(), placeholder.as_proxy()),
+                {},
+            ),
+        )
+        assert isinstance(result, variables.TensorVariable)
+        result.class_type = torch.nn.Parameter
+        # In reconstruct() should use the original parameter. The one returned by the graph will be an alias.
+        result.source = placeholder.source
+        return result
 
     @staticmethod
     def _nn_param_via_prefix_insert(tx, data, requires_grad):

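One way to see this rewrite in action (a sketch using a custom debug backend, not part of the commit): the captured graph should contain a call_function node for tracable_create_parameter, binding the intermediate tensor into the synthetic placeholder input.

import torch

def debug_backend(gm, example_inputs):
    print(gm.graph)  # expect a tracable_create_parameter call_function node
    return gm.forward

def fn(x):
    return torch.nn.Parameter(x + 1)

opt = torch.compile(fn, backend=debug_backend, fullgraph=True)
opt(torch.randn(4))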
torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py (+21 -1)

@@ -181,6 +181,26 @@ def _output_node(gm: torch.fx.GraphModule) -> torch.fx.Node:
     return next(n for n in reversed(gm.graph.nodes) if n.op == "output")
 
 
+def _input_node(gm: torch.fx.GraphModule, i: int) -> torch.fx.Node:
+    """Fetch the i-th placeholder in the graph"""
+    seen = 0
+    for n in gm.graph.nodes:
+        if n.op == "placeholder":
+            if seen == i:
+                return n
+            seen += 1
+    raise IndexError(f"input {i} does not exist, only {seen} inputs in graph")
+
+
+def _can_detach(node: torch.fx.Node):
+    """
+    Avoid calling .detach() on inputs passed to _bind_nn_parameter()
+    """
+    from torch._dynamo.create_parameter_op import _bind_nn_parameter
+
+    return all(n.target is not _bind_nn_parameter for n in node.users)
+
+
 def aot_dispatch_autograd(
     flat_fn,
     flat_args: List[Any],
@@ -317,7 +337,7 @@ def aot_dispatch_autograd(
         == len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset
     )
     for i, (bw_out) in enumerate(bw_outs):
-        if bw_out is None:
+        if bw_out is None and _can_detach(_input_node(fx_g, i)):
             _indices_of_inps_to_detach.append(i)
 
     if aot_config.enable_log:

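_input_node is simply an indexed walk over the graph's placeholders; in plain torch.fx terms it is equivalent to the sketch below. _can_detach then keeps AOTAutograd from detaching any input consumed by _bind_nn_parameter, since the bind mutates the placeholder in place and a detached alias would break that.

import torch.fx

def f(a, b):
    return a + b

gm = torch.fx.symbolic_trace(f)
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
# _input_node(gm, 1) returns the same node as placeholders[1];
# past the end it raises IndexError instead of returning None.
print(placeholders)  # [a, b]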
torch/_inductor/ir.py (+23)

@@ -4496,6 +4496,29 @@ def __init__(self, variable, new_size):
         mark_node_as_mutating(self, variable)
 
 
+class BindNNParameter(ExternKernelAlloc):
+    def __init__(self, variable, placeholder):
+        variable.freeze_layout()
+        super().__init__(
+            variable.get_layout(),
+            [variable, placeholder],
+            python_kernel_name="torch.ops.prims._bind_nn_parameter",
+        )
+        V.graph.never_reuse_buffers.add(variable.data.get_name())
+        V.graph.never_reuse_buffers.add(placeholder.get_name())
+        V.graph.never_reuse_buffers.add(self.get_name())
+        mark_node_as_mutating(self, variable, placeholder)
+
+    def get_alias_names(self):
+        return [self.inputs[0].get_name(), self.inputs[1].get_name()]
+
+    def get_mutation_names(self):
+        return [self.inputs[1].get_name()]
+
+    def has_side_effects(self):
+        return True
+
+
 class ScatterFallback(ExternKernel):
     """
     This needs to be a custom class to handle mutation properly.

torch/_inductor/lowering.py (+7)

@@ -13,6 +13,7 @@
 import torch.ao.quantization.fx._decomposed
 import torch.fx
 import torch.utils._pytree as pytree
+from torch._dynamo.create_parameter_op import _bind_nn_parameter
 from torch._higher_order_ops.triton_kernel_wrap import (
     triton_kernel_wrapper_functional,
     triton_kernel_wrapper_mutation,
@@ -5924,6 +5925,12 @@ def resize_storage_bytes_(variable, new_size):
     return variable
 
 
+@register_lowering(_bind_nn_parameter)
+def create_nn_parameter(self, placeholder):
+    self.realize()
+    return TensorBox.create(ir.BindNNParameter(self, placeholder))
+
+
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 
 make_fallback(auto_functionalized)

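Design note: BindNNParameter reports both inputs as aliases and the placeholder as mutated, and all three buffers are added to never_reuse_buffers, so inductor will neither recycle the storage that the live nn.Parameter now aliases nor schedule the bind away; has_side_effects() likewise protects the node from dead-code elimination, mirroring the torch.fx.node.has_side_effect marking on the prim itself.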