Commit 38bbe37

malfet authored and pytorchmergebot committed
Enable CI on SM89 (pytorch#140305)
Uses EC2 G6 instances (based on the NVIDIA L4 GPU), added to the scale config in pytorch/test-infra#5376. To enable more balanced sharding, had to push pytorch/test-infra@148ae19.

Added `@xfailIfSM89` to the following tests:
- test_fp8_pattern_2
- test_original_aten_preserved_split_addmm
- test_sparse_semi_structured_scaled_mm
- test_sparse_semi_structured_scaled_mm_fp8
- test_sparse_fp8fp8_mm

Increased the tolerance to 2e-4 for `RNNTest.BidirectionalMultilayerGRU_CPU_vs_CUDA`.

Skipped the following inductor tests (which either flakily OOM or time out):
- test_reduction_fn_std_float64
- test_reduction_fn_var_mean_float64
- test_multi_output_unbacked_custom_op

Pull Request resolved: pytorch#140305
Approved by: https://github.com/wdvr, https://github.com/ZainRizvi
1 parent af88326 commit 38bbe37

9 files changed: +74 -21 lines
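
For readers skimming the diffs below: SM89 is CUDA compute capability 8.9, i.e. the NVIDIA L4 GPUs backing the new G6 runners. A minimal sketch of the runtime check this commit relies on is shown here; the helper name `running_on_sm89` is hypothetical and only illustrates the same capability check that the new `IS_SM89` flag (added to torch/testing/_internal/common_cuda.py further down) performs.

# Illustrative sketch, not part of this diff.
import torch

def running_on_sm89() -> bool:
    # get_device_capability() returns (major, minor), e.g. (8, 9) for an L4.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9)

if running_on_sm89():
    print("Running on an SM89 (compute capability 8.9) GPU")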

.github/workflows/pull.yml (+15 -15)

@@ -476,35 +476,35 @@ jobs:
         ]}
     secrets: inherit

-  linux-focal-cuda12_4-py3_10-gcc9-sm86-build:
-    name: linux-focal-cuda12.4-py3.10-gcc9-sm86
+  linux-focal-cuda12_4-py3_10-gcc9-sm89-build:
+    name: linux-focal-cuda12.4-py3.10-gcc9-sm89
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86
+      build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm89
       docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9
-      cuda-arch-list: 8.6
+      cuda-arch-list: 8.9
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
-          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
-          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
-          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
-          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
+          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
+          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
+          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
+          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
+          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
         ]}
     secrets: inherit

-  linux-focal-cuda12_4-py3_10-gcc9-sm86-test:
-    name: linux-focal-cuda12.4-py3.10-gcc9-sm86
+  linux-focal-cuda12_4-py3_10-gcc9-sm89-test:
+    name: linux-focal-cuda12.4-py3.10-gcc9-sm89
     uses: ./.github/workflows/_linux-test.yml
     needs:
-      - linux-focal-cuda12_4-py3_10-gcc9-sm86-build
+      - linux-focal-cuda12_4-py3_10-gcc9-sm89-build
       - target-determination
     with:
-      build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86
-      docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.docker-image }}
-      test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm86-build.outputs.test-matrix }}
+      build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm89
+      docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm89-build.outputs.docker-image }}
+      test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-sm89-build.outputs.test-matrix }}
     secrets: inherit

   linux-jammy-py3-clang12-executorch-build:

test/cpp/api/rnn.cpp (+33 -3)

@@ -3,6 +3,9 @@
 #include <torch/torch.h>

 #include <test/cpp/api/support.h>
+#ifdef USE_CUDA
+#include <ATen/cuda/CUDAContext.h>
+#endif

 using namespace torch::nn;
 using namespace torch::test;
@@ -552,6 +555,15 @@ TEST_F(RNNTest, BidirectionalLSTMReverseForward_CUDA) {
 }

 TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
+#ifdef USE_CUDA
+  // Get device properties
+  const auto prop = at::cuda::getCurrentDeviceProperties();
+  // TODO: Investigate why results on sm89 are much less accurate
+  // See https://github.com/pytorch/pytorch/issues/141915
+  const auto tolerance = prop->major == 8 && prop->minor == 9 ? 2e-4 : 1e-5;
+#else
+  constexpr auto tolerance = 1e-5;
+#endif
   // Create two GRUs with the same options
   auto opt =
       GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
@@ -600,13 +612,22 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
         ASSERT_NEAR(
             std::get<0>(output_cpu)[i][j][k].item<float>(),
             std::get<0>(output_cuda)[i][j][k].item<float>(),
-            1e-5);
+            tolerance);
      }
    }
  }
}

TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
+#ifdef USE_CUDA
+  // Get device properties
+  const auto prop = at::cuda::getCurrentDeviceProperties();
+  // TODO: Investigate why results on sm89 are much less accurate
+  // See https://github.com/pytorch/pytorch/issues/141915
+  const auto tolerance = prop->major == 8 && prop->minor == 9 ? 2e-4 : 1e-5;
+#else
+  constexpr auto tolerance = 1e-5;
+#endif
   // Create two LSTMs with the same options
   auto opt =
       LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
@@ -654,13 +675,22 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
         ASSERT_NEAR(
             std::get<0>(output_cpu)[i][j][k].item<float>(),
             std::get<0>(output_cuda)[i][j][k].item<float>(),
-            1e-5);
+            tolerance);
      }
    }
  }
}

TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) {
+#ifdef USE_CUDA
+  // Get device properties
+  const auto prop = at::cuda::getCurrentDeviceProperties();
+  // TODO: Investigate why results on sm89 are much less accurate
+  // See https://github.com/pytorch/pytorch/issues/141915
+  const auto tolerance = prop->major == 8 && prop->minor == 9 ? 2e-4 : 1e-5;
+#else
+  constexpr auto tolerance = 1e-5;
+#endif
   // Create two LSTMs with the same options
   auto opt = LSTMOptions(2, 4)
                  .num_layers(3)
@@ -711,7 +741,7 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTMProj_CPU_vs_CUDA) {
         ASSERT_NEAR(
             std::get<0>(output_cpu)[i][j][k].item<float>(),
             std::get<0>(output_cuda)[i][j][k].item<float>(),
-            1e-5);
+            tolerance);
      }
    }
  }

test/inductor/test_cooperative_reductions.py (+5)

@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import unittest
 from typing import Any, Dict, List, Type

 import sympy
@@ -11,6 +12,7 @@
 from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import run_and_get_code
+from torch.testing._internal.common_cuda import IS_SM89
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -60,6 +62,9 @@ def run_and_check(self, fn, args, *, expect_kernel_count=1):
     )
     @parametrize("dtype", [torch.float16, torch.float32, torch.float64])
     def test_reduction_fns(self, name, dtype):
+        if IS_SM89 and dtype == torch.float64 and name in ["std", "var_mean"]:
+            raise unittest.SkipTest("Timeouts on SM89")
+
         def fn(x, y):
             return reduction_fn(x + y, dim=-1)

test/inductor/test_kernel_benchmark.py (+2)

@@ -13,6 +13,7 @@
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import fresh_inductor_cache
 from torch.testing import FileCheck
+from torch.testing._internal.common_cuda import xfailIfSM89
 from torch.testing._internal.common_device_type import expectedFailureXPU
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

@@ -384,6 +385,7 @@ def f(a, b, c):
         self.check_bandwidth(compiled_module, "0.006")

     @expectedFailureXPU
+    @xfailIfSM89
     @config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
     def test_slice_mm_bandwidth_computation(self):
         M, N, K = 1000, 2000, 3000

test/inductor/test_loop_ordering.py (+2 -1)

@@ -18,7 +18,7 @@
 from torch._inductor.test_operators import realize
 from torch._inductor.utils import sympy_index_symbol
 from torch._inductor.virtualized import ops, V
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, xfailIfSM89
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 from torch.utils._pytree import tree_map
 from torch.utils._sympy.functions import ModularIndexing
@@ -406,6 +406,7 @@ def f(x, scale):
         self.assertEqual(1, metrics.generated_kernel_count)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
+    @xfailIfSM89
     def test_fp8_pattern_2(self):
         """
         This test repros the fp8 fusion relation issue here:

test/inductor/test_pattern_matcher.py (+2 -1)

@@ -33,7 +33,7 @@
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
-from torch.testing._internal.common_cuda import SM80OrLater
+from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89
 from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf
 from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu
 from torch.testing._internal.inductor_utils import (
@@ -1309,6 +1309,7 @@ def remap_fake_tensor(x):
         self.assertTrue(pattern.pattern_eq(search_fn_pattern))

     @skipIfXpu
+    @xfailIfSM89
     @inductor_config.patch(
         {
             "triton.unique_kernel_names": "original_aten",

test/inductor/test_torchinductor_dynamic_shapes.py (+5)

@@ -20,6 +20,7 @@
 from torch._inductor.utils import run_and_get_code
 from torch._inductor.virtualized import V
 from torch.testing import FileCheck
+from torch.testing._internal.common_cuda import IS_SM89
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCPU,
@@ -575,6 +576,10 @@ def f(x):

         f(torch.tensor([3], device=device))

+    @unittest.skipIf(
+        IS_SM89,
+        "Fails(with OOMS) on SM89, see https://github.com/pytorch/pytorch/issues/141915",
+    )
     @torch._dynamo.config.patch(
         capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
     )

test/test_sparse_semi_structured.py (+4 -1)

@@ -21,7 +21,7 @@
 )

 from torch.testing import make_tensor
-from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8
+from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, xfailIfSM89
 from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
@@ -1047,6 +1047,7 @@ def setUp(self):
             self.skipTest('cuSPARSELt not enabled')

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @xfailIfSM89
     @parametrize("dense_input_shape", [(256, 128)])
     def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
         if torch.backends.cusparselt.version() < 602:
@@ -1066,6 +1067,7 @@ def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
         dense_result = torch.mm(A_fp8_sparse, B_fp8)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @xfailIfSM89
     def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
         (k, l, m) = (32, 64, 32)
         x = rand_sparse_semi_structured_mask(k, l, dtype=torch.float8_e4m3fn, device=device)
@@ -1082,6 +1084,7 @@ def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
         torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @xfailIfSM89
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
     @parametrize("dense_input_shape", [(256, 128)])
     def test_sparse_semi_structured_scaled_mm(

torch/testing/_internal/common_cuda.py (+6)

@@ -9,6 +9,7 @@
 import inspect
 import contextlib
 import os
+import unittest


 CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
@@ -33,6 +34,7 @@
 SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))

 IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
+IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))

 def CDNA2OrLater():
     if TEST_WITH_ROCM:
@@ -316,6 +318,10 @@ def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.
     ) + (data, loss_fn, skip_iter)


+def xfailIfSM89(func):
+    return func if not IS_SM89 else unittest.expectedFailure(func)
+
+
 # Importing this module should NOT eagerly initialize CUDA
 if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
     assert not torch.cuda.is_initialized()
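
As a usage note: the new `IS_SM89` flag and `xfailIfSM89` decorator are applied exactly as in the test diffs above. A minimal, self-contained sketch (hypothetical test class and test names, not part of this commit) might look like this:

# Hypothetical example showing how the helpers added to
# torch.testing._internal.common_cuda in this commit are used in a test file.
import unittest

from torch.testing._internal.common_cuda import IS_SM89, xfailIfSM89


class MyCudaTests(unittest.TestCase):  # hypothetical test class
    @xfailIfSM89  # marked expectedFailure only when running on an SM89 GPU
    def test_known_sm89_failure(self):
        ...

    def test_heavy_reduction(self):
        if IS_SM89:
            # Skip instead of xfail when the test flakily OOMs or times out on SM89.
            raise unittest.SkipTest("Timeouts on SM89")
        ...


if __name__ == "__main__":
    unittest.main()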
