|
| 1 | +# Owner(s): ["module: inductor"] |
| 2 | +import copy |
| 3 | +import os |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch import nn |
| 7 | +from torch._dynamo.test_case import run_tests, TestCase |
| 8 | +from torch._dynamo.utils import same |
| 9 | +from torch.testing._internal.common_utils import TEST_WITH_ROCM |
| 10 | +from torch.testing._internal.inductor_utils import HAS_CUDA |
| 11 | + |
| 12 | +USE_DDP_WRAPPER = os.environ.get("USE_DDP_WRAPPER", "1") == "1" |
| 13 | + |
| 14 | + |
| 15 | +class Model2Conv(nn.Module): |
| 16 | + def __init__(self, dim=512, manual_graph_break=False): |
| 17 | + super().__init__() |
| 18 | + self.conv1 = nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False) |
| 19 | + self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False) |
| 20 | + self.manual_graph_break = manual_graph_break |
| 21 | + |
| 22 | + def forward(self, x): |
| 23 | + x = self.conv1(x) |
| 24 | + if self.manual_graph_break: |
| 25 | + torch._dynamo.graph_break() |
| 26 | + x = self.conv2(x) |
| 27 | + return x |
| 28 | + |
| 29 | + def get_example_inputs(self): |
| 30 | + return (torch.rand(2, 3, 16, 16),) |
| 31 | + |
| 32 | + |
| 33 | +class TestLayoutOptim(TestCase): |
| 34 | + @classmethod |
| 35 | + def setUpClass(cls): |
| 36 | + super().setUpClass() |
| 37 | + |
| 38 | + import torch.distributed as dist |
| 39 | + |
| 40 | + port = 10001 |
| 41 | + dist.init_process_group( |
| 42 | + backend="nccl", init_method=f"tcp://localhost:{port}", world_size=1, rank=0 |
| 43 | + ) |
| 44 | + |
| 45 | + def verify_accuracy( |
| 46 | + self, model_class, use_ddp_wrapper=USE_DDP_WRAPPER, is_train=False |
| 47 | + ): |
| 48 | + # there are 2 potential ways to introduce graph breaks |
| 49 | + # 1. manually |
| 50 | + # 2. using DDP |
| 51 | + # if we are not using DDP to introduce graph breaks, do that manually |
| 52 | + def wrap_mod(m): |
| 53 | + if is_train: |
| 54 | + |
| 55 | + def f(*inp): |
| 56 | + x = m(*inp) |
| 57 | + x.sum().backward() |
| 58 | + |
| 59 | + grads = [] |
| 60 | + for name, param in m.named_parameters(): |
| 61 | + grad = param.grad |
| 62 | + if param.grad is None: |
| 63 | + grad = torch.zeros_like(param) |
| 64 | + grads.append(grad) |
| 65 | + return grads |
| 66 | + |
| 67 | + return f |
| 68 | + else: |
| 69 | + return m |
| 70 | + |
| 71 | + manual_graph_break = not use_ddp_wrapper |
| 72 | + mod = model_class(manual_graph_break=manual_graph_break).cuda() |
| 73 | + inp = [t.cuda() for t in mod.get_example_inputs()] |
| 74 | + expected_out = wrap_mod(mod)(*inp) |
| 75 | + |
| 76 | + fp64_mod = copy.deepcopy(mod).to(torch.float64) |
| 77 | + fp64_inp = [t.to(torch.float64) for t in copy.deepcopy(inp)] |
| 78 | + fp64_out = wrap_mod(fp64_mod)(*fp64_inp) |
| 79 | + |
| 80 | + if use_ddp_wrapper: |
| 81 | + from torch.nn.parallel import DistributedDataParallel as DDP |
| 82 | + |
| 83 | + ddp_wrapped_mod = DDP(mod) |
| 84 | + opt_mod = torch.compile(wrap_mod(ddp_wrapped_mod)) |
| 85 | + else: |
| 86 | + opt_mod = torch.compile(wrap_mod(mod)) |
| 87 | + actual_out = opt_mod(*inp) |
| 88 | + |
| 89 | + if is_train: |
| 90 | + self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out)) |
| 91 | + else: |
| 92 | + expected_sum = expected_out.sum() |
| 93 | + actual_sum = actual_out.sum() |
| 94 | + print(f"Expected sum {expected_sum}, actual sum {actual_sum}") |
| 95 | + self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out)) |
| 96 | + |
| 97 | + def verify_accuracy_for_infer(self, *args, **kwargs): |
| 98 | + self.verify_accuracy(*args, **kwargs, is_train=False) |
| 99 | + |
| 100 | + def verify_accuracy_for_train(self, *args, **kwargs): |
| 101 | + self.verify_accuracy(*args, **kwargs, is_train=True) |
| 102 | + |
| 103 | + def test_2conv_with_graph_break(self): |
| 104 | + """ |
| 105 | + Make sure graph break does not cause any accuracy issue. |
| 106 | + """ |
| 107 | + self.verify_accuracy_for_infer(Model2Conv) |
| 108 | + |
| 109 | + def test_3conv_with_graph_break(self): |
| 110 | + class Model(nn.Module): |
| 111 | + def __init__( |
| 112 | + self, dim=512, patch_size=7, kernel_size=7, manual_graph_break=False |
| 113 | + ): |
| 114 | + super().__init__() |
| 115 | + self.seq = nn.Sequential( |
| 116 | + nn.Conv2d( |
| 117 | + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False |
| 118 | + ), |
| 119 | + nn.Conv2d( |
| 120 | + dim, dim, kernel_size, groups=dim, padding="same", bias=False |
| 121 | + ), |
| 122 | + ) |
| 123 | + self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=False) |
| 124 | + self.manual_graph_break = manual_graph_break |
| 125 | + |
| 126 | + def forward(self, x): |
| 127 | + x = self.seq(x) |
| 128 | + if self.manual_graph_break: |
| 129 | + torch._dynamo.graph_break() |
| 130 | + x = self.conv(x) |
| 131 | + return x |
| 132 | + |
| 133 | + def get_example_inputs(self): |
| 134 | + return (torch.randn(2, 3, 16, 16),) |
| 135 | + |
| 136 | + self.verify_accuracy_for_infer(Model) |
| 137 | + |
| 138 | + def test_keep_output_layout_infer(self): |
| 139 | + class Model(nn.Module): |
| 140 | + def __init__(self): |
| 141 | + super().__init__() |
| 142 | + self.conv = nn.Conv2d( |
| 143 | + 3, 128, kernel_size=3, padding=1, stride=1, bias=False |
| 144 | + ) |
| 145 | + |
| 146 | + def forward(self, x): |
| 147 | + x = self.conv(x) |
| 148 | + return x |
| 149 | + |
| 150 | + def get_example_inputs(self): |
| 151 | + return (torch.randn(2, 3, 5, 5),) |
| 152 | + |
| 153 | + mod = Model().cuda() |
| 154 | + inp = [t.cuda() for t in mod.get_example_inputs()] |
| 155 | + out = mod(*inp) |
| 156 | + |
| 157 | + opt_mod = torch.compile(mod) |
| 158 | + opt_out = opt_mod(*inp) |
| 159 | + |
| 160 | + # We should be able to do view on eager output |
| 161 | + out.view(5, -1) |
| 162 | + |
| 163 | + # We should be able to do view on the output of the optimized module |
| 164 | + # Note that if the output is channels last, the view op will fail. |
| 165 | + opt_out.view(5, -1) |
| 166 | + |
| 167 | + def test_training_acc(self): |
| 168 | + self.verify_accuracy_for_train(Model2Conv) |
| 169 | + |
| 170 | + |
| 171 | +if __name__ == "__main__": |
| 172 | + if HAS_CUDA and not TEST_WITH_ROCM: |
| 173 | + run_tests() |
0 commit comments