Skip to content

Commit a17cd22

Browse files
masnesral authored and pytorchmergebot committed
[inductor] Enable FX graph caching on another round of inductor tests (pytorch#121994)
Summary: Enabling caching for these tests was blocked by pytorch#121686 Pull Request resolved: pytorch#121994 Approved by: https://github.com/eellison
1 parent 7c5e29a commit a17cd22

6 files changed

+12
-12
lines changed

test/inductor/test_compiled_optimizers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import torch._inductor.cudagraph_trees
1515
from torch._inductor import config
1616

17+
from torch._inductor.test_case import TestCase
18+
1719
from torch.optim import (
1820
Adadelta,
1921
Adagrad,
@@ -39,8 +41,6 @@
3941
optim_db,
4042
optims,
4143
)
42-
43-
from torch.testing._internal.common_utils import TestCase
4444
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA, has_triton
4545
from torch.testing._internal.triton_utils import requires_cuda
4646

@@ -598,7 +598,7 @@ def loop(opt, c):
598598
instantiate_device_type_tests(CompiledOptimizerParityTests, globals())
599599

600600
if __name__ == "__main__":
601-
from torch._dynamo.test_case import run_tests
601+
from torch._inductor.test_case import run_tests
602602

603603
if HAS_CPU or HAS_CUDA:
604604
run_tests(needs="filelock")

test/inductor/test_cutlass_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Callable, List
66

77
import torch
8-
from torch._dynamo.test_case import run_tests, TestCase
98
from torch._dynamo.utils import counters
109
from torch._inductor import config
10+
from torch._inductor.test_case import run_tests, TestCase
1111
from torch.testing._internal.common_cuda import SM75OrLater, SM80OrLater, SM90OrLater
1212
from torch.testing._internal.common_utils import (
1313
instantiate_parametrized_tests,

test/inductor/test_group_batch_fusion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import torch
77
import torch._inductor
8-
from torch._dynamo.test_case import run_tests, TestCase
98
from torch._dynamo.utils import counters, optimus_scuba_log
9+
from torch._inductor.test_case import run_tests, TestCase
1010
from torch.testing._internal.inductor_utils import HAS_CUDA
1111

1212
try:

test/inductor/test_torchinductor_dynamic_shapes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch._custom_ops as custom_ops
1212
import torch.library
1313
from torch._dynamo.testing import make_test_cls_with_patches
14+
from torch._inductor.test_case import TestCase
1415
from torch.testing._internal.common_device_type import (
1516
instantiate_device_type_tests,
1617
onlyCPU,
@@ -24,7 +25,6 @@
2425
TEST_CUDA_MEM_LEAK_CHECK,
2526
TEST_WITH_ASAN,
2627
TEST_WITH_ROCM,
27-
TestCaseBase as TestCase,
2828
)
2929
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
3030

@@ -613,7 +613,7 @@ def func(x, fn, a):
613613
instantiate_device_type_tests(TestInductorDynamic, globals())
614614

615615
if __name__ == "__main__":
616-
from torch._dynamo.test_case import run_tests
616+
from torch._inductor.test_case import run_tests
617617

618618
# Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
619619
if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:

test/inductor/test_torchinductor_opinfo.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch
1414

1515
from torch._dispatch.python import enable_python_dispatcher
16-
from torch._dynamo.test_case import run_tests
16+
from torch._inductor.test_case import run_tests, TestCase
1717
from torch._subclasses.fake_tensor import (
1818
DataDependentOutputException,
1919
DynamicOutputShapeException,
@@ -40,7 +40,6 @@
4040
TEST_MKL,
4141
TEST_WITH_ASAN,
4242
TEST_WITH_ROCM,
43-
TestCase,
4443
)
4544
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA
4645
from torch.utils._python_dispatch import TorchDispatchMode

test/inductor/test_unbacked_symints.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77
from torch._dynamo import config as dynamo_config
88
from torch._inductor import config as inductor_config
9+
from torch._inductor.test_case import TestCase as InductorTestCase
910
from torch._inductor.utils import is_big_gpu
1011
from torch.testing import make_tensor
1112
from torch.testing._internal.common_device_type import instantiate_device_type_tests
12-
from torch.testing._internal.common_utils import IS_LINUX, TestCase as TorchTestCase
13+
from torch.testing._internal.common_utils import IS_LINUX
1314
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, skipCUDAIf
1415

1516

16-
class TestUnbackedSymints(TorchTestCase):
17+
class TestUnbackedSymints(InductorTestCase):
1718
@skipCUDAIf(not HAS_CUDA, "requires cuda")
1819
@dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
1920
def test_expand(self, device):
@@ -147,7 +148,7 @@ def fn(x):
147148
)
148149

149150
if __name__ == "__main__":
150-
from torch._dynamo.test_case import run_tests
151+
from torch._inductor.test_case import run_tests
151152

152153
if IS_LINUX and HAS_CUDA and is_big_gpu(0):
153154
run_tests()

0 commit comments

Comments (0)