Commit 4742080

etaf authored and pytorchmergebot committed
1 parent c418a9a commit 4742080

19 files changed (+432, -49 lines)

.lintrunner.toml (+1)

@@ -263,6 +263,7 @@ exclude_patterns = [
     'torch/csrc/jit/**/*',
     'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
     'torch/csrc/utils/pythoncapi_compat.h',
+    'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
 ]
 init_command = [
     'python3',

build_variables.bzl (+1)

@@ -792,6 +792,7 @@ libtorch_python_xpu_sources = [
     "torch/csrc/xpu/Event.cpp",
     "torch/csrc/xpu/Module.cpp",
     "torch/csrc/xpu/Stream.cpp",
+    "torch/csrc/inductor/aoti_torch/shim_xpu.cpp",
 ]

 libtorch_python_core_sources = [

caffe2/CMakeLists.txt (+4)

@@ -1116,6 +1116,10 @@ if(USE_XPU)

     # Set cached ${ATen_XPU_INCLUDE_DIRS} to torch
     include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS})
+    message(INFO "Install ${TORCH_XPU_OPS_DIR}/src/ATen/xpu to ${TORCH_INSTALL_INCLUDE_DIR}/ATen/xpu")
+    install(DIRECTORY "${TORCH_XPU_OPS_DIR}/src/ATen/xpu"
+            DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/ATen/
+            FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")

   endif()
 endif()
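As a quick sanity check (not part of the commit), an XPU-enabled build that picks up this install rule should end up with the torch-xpu-ops ATen headers under the torch include tree; the exact destination below is inferred from the install() DESTINATION above and is an assumption.

# Hedged sketch: verify that ATen XPU headers landed under <torch>/include/ATen/xpu,
# the destination implied by the install() rule in this hunk.
import os
import torch

xpu_include_dir = os.path.join(os.path.dirname(torch.__file__), "include", "ATen", "xpu")
print(xpu_include_dir, "exists:", os.path.isdir(xpu_include_dir))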

setup.py (+1)

@@ -1289,6 +1289,7 @@ def main():
         "include/torch/csrc/inductor/aoti_torch/*.h",
         "include/torch/csrc/inductor/aoti_torch/c/*.h",
         "include/torch/csrc/inductor/aoti_torch/generated/*.h",
+        "include/torch/csrc/inductor/aoti_torch/generated/extend/*.h",
         "include/torch/csrc/jit/*.h",
         "include/torch/csrc/jit/backends/*.h",
         "include/torch/csrc/jit/generated/*.h",

test/inductor/test_cuda_cpp_wrapper.py (+57 -30)

@@ -7,11 +7,12 @@
 import torch
 from torch._inductor import config
 from torch._inductor.test_case import TestCase as InductorTestCase
+from torch._inductor.utils import is_gpu
 from torch.testing._internal.common_device_type import (
     get_desired_device_type_test_bases,
 )
 from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


 try:
@@ -38,29 +39,40 @@
     raise


-_desired_test_bases = get_desired_device_type_test_bases()
-RUN_CUDA = (
-    HAS_CUDA
-    and any(getattr(x, "device_type", "") == "cuda" for x in _desired_test_bases)
+_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
+RUN_GPU = (
+    HAS_GPU
+    and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
     and not TEST_WITH_ASAN
 )


-class CudaWrapperTemplate:
+class GpuWrapperTemplate:
     pass


-class TestCudaWrapper(InductorTestCase):
-    device = "cuda"
+class TestGpuWrapper(InductorTestCase):
+    device = GPU_TYPE


-class DynamicShapesCudaWrapperCudaTests(InductorTestCase):
-    device = "cuda"
+class DynamicShapesGpuWrapperGpuTests(InductorTestCase):
+    device = GPU_TYPE


-test_failures_cuda_wrapper = {
+test_failures_gpu_wrapper = {
     "test_mm_plus_mm2_cuda_dynamic_shapes": test_torchinductor.TestFailure(
-        ("cuda_wrapper",), is_skip=True
+        ("gpu_wrapper",), is_skip=True
+    ),
+    "test_randint_xpu": test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=False),
+    "test_randint_xpu_dynamic_shapes": test_torchinductor.TestFailure(
+        ("gpu_wrapper",), is_skip=False
+    ),
+    # ATen ops: scaled_dot_product_efficient_attention not implemented on XPU.
+    "test_scaled_dot_product_efficient_attention_xpu": test_torchinductor.TestFailure(
+        ("gpu_wrapper",), is_skip=False
+    ),
+    "test_scaled_dot_product_efficient_attention_xpu_dynamic_shapes": test_torchinductor.TestFailure(
+        ("gpu_wrapper",), is_skip=False
     ),
 }

@@ -114,20 +126,34 @@ def fn(self):
     fn.__dict__ = copy.deepcopy(func.__dict__)
     if condition:
         setattr(
-            CudaWrapperTemplate,
+            GpuWrapperTemplate,
             test_name,
             fn,
         )


-if RUN_CUDA:
+if RUN_GPU:

     class BaseTest(NamedTuple):
         name: str
-        device: str = "cuda"
+        device: str = GPU_TYPE
         tests: InductorTestCase = test_torchinductor.GPUTests()
         check_code: bool = True

+    # XPU Not implemented yet
+    XPU_BASE_TEST_SKIP = [
+        "test_foreach_cpp_wrapper",
+        "test_enable_dynamic_shapes_cpp_wrapper",
+        "test_dynamic_shapes_persistent_reduction_mixed_x_dim",
+        "test_cat_slice_cat",
+        "test_mm_plus_mm2",
+        "test_mm_plus_mm3",
+        "test_addmm",
+        "test_linear_relu",
+        "test_fft_real_input",
+        "test_fft_real_input_real_output",
+    ]
+
     # Maintain two separate test lists for cuda and cpp for now
     for item in [
         BaseTest("test_add_complex"),
@@ -236,40 +262,41 @@ class BaseTest(NamedTuple):
             tests=test_select_algorithm.TestSelectAlgorithm(),
         ),
     ]:
+        if item.device == "xpu" and item.name in XPU_BASE_TEST_SKIP:
+            continue
         make_test_case(item.name, item.device, item.tests, check_code=item.check_code)

     from torch._inductor.utils import is_big_gpu

-    if is_big_gpu(0):
+    if GPU_TYPE == "cuda" and is_big_gpu(0):
         skip_list = ["test_addmm", "test_linear_relu"]
         # need to skip instead of omit, otherwise fbcode ci can be flaky
         for test_name in skip_list:
-            test_failures_cuda_wrapper[
+            test_failures_gpu_wrapper[
                 f"{test_name}_cuda"
-            ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)
-            test_failures_cuda_wrapper[
-                f"{test_name}_cuda_dynamic_shapes"
-            ] = test_torchinductor.TestFailure(("cuda_wrapper",), is_skip=True)
+            ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)
+            test_failures_gpu_wrapper[
+                f"{test_name}_gpu_dynamic_shapes"
+            ] = test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True)

     test_torchinductor.copy_tests(
-        CudaWrapperTemplate, TestCudaWrapper, "cuda_wrapper", test_failures_cuda_wrapper
+        GpuWrapperTemplate, TestGpuWrapper, "gpu_wrapper", test_failures_gpu_wrapper
     )

-    DynamicShapesCudaWrapperTemplate = (
-        test_torchinductor_dynamic_shapes.make_dynamic_cls(CudaWrapperTemplate)
+    DynamicShapesGpuWrapperTemplate = (
+        test_torchinductor_dynamic_shapes.make_dynamic_cls(GpuWrapperTemplate)
     )

     test_torchinductor.copy_tests(
-        DynamicShapesCudaWrapperTemplate,
-        DynamicShapesCudaWrapperCudaTests,
-        "cuda_wrapper",
-        test_failures_cuda_wrapper,
+        DynamicShapesGpuWrapperTemplate,
+        DynamicShapesGpuWrapperGpuTests,
+        "gpu_wrapper",
+        test_failures_gpu_wrapper,
         xfail_prop="_expected_failure_dynamic_wrapper",
     )

 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

-    print(f"FS: run_cuda {RUN_CUDA}")
-    if RUN_CUDA:
+    if RUN_GPU:
         run_tests(needs="filelock")
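For orientation, here is a hedged sketch (not part of the commit) of the gating pattern the patched test now relies on: GPU_TYPE resolves to whichever GPU backend is available and is_gpu accepts both "cuda" and "xpu", so one test template can drive either backend.

# Hedged sketch: gate GPU-only tests on whichever backend (CUDA or XPU) is present.
# Assumes torch.testing._internal.inductor_utils exposes GPU_TYPE/HAS_GPU and
# torch._inductor.utils exposes is_gpu, as imported in the patched test file.
from torch._inductor.utils import is_gpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU

if HAS_GPU and is_gpu(GPU_TYPE):
    # GPU_TYPE is "cuda" on NVIDIA builds and "xpu" on Intel GPU builds, so a
    # single test template can be parametrized with device = GPU_TYPE.
    print(f"running GPU wrapper tests on device: {GPU_TYPE}")
else:
    print("no supported GPU backend available; skipping")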

test/inductor/test_memory_planning.py (-2)

@@ -3,7 +3,6 @@
 import sys
 import unittest

-from torch.testing._internal.common_device_type import expectedFailureXPU
 from torch.testing._internal.common_utils import (
     IS_CI,
     IS_WINDOWS,
@@ -71,7 +70,6 @@ def test_python_wrapper(self):
         )
         self.assertTrue(same(f(*args), result))

-    @expectedFailureXPU
     def test_cpp_wrapper(self):
         f, args = self._generate(device=GPU_TYPE)
         compiled = torch.compile(f, dynamic=True)

test/inductor/test_triton_kernels.py (-1)

@@ -3265,7 +3265,6 @@ def f(x, y):
         gm = make_fx(f, tracing_mode=tracing_mode)(x, x)
         self.assertEqual(gm(x, x), x + x)

-    @skipIfXpu
     @requires_gpu
     @patch.object(torch._inductor.config, "cpp_wrapper", True)
     @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True)

torch/_inductor/codegen/common.py (+1)

@@ -382,6 +382,7 @@ def init_backend_registration():
            "xpu",
            TritonScheduling,
            PythonWrapperCodegen,
+           CppWrapperGpu,
        )

    private_backend = torch._C._get_privateuse1_backend_name()
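For context, the registration above gives the "xpu" backend a C++ wrapper codegen alongside its Python wrapper codegen. A hedged sketch of querying that registry; it assumes init_backend_registration and get_wrapper_codegen_for_device are exported from torch._inductor.codegen.common, as in the file being patched.

# Hedged sketch: after this change, asking for the C++ wrapper codegen of the
# "xpu" backend is expected to return CppWrapperGpu instead of nothing.
from torch._inductor.codegen.common import (
    get_wrapper_codegen_for_device,
    init_backend_registration,
)

init_backend_registration()  # populates the device -> codegen registry
python_wrapper = get_wrapper_codegen_for_device("xpu", cpp_wrapper=False)
cpp_wrapper = get_wrapper_codegen_for_device("xpu", cpp_wrapper=True)
print(python_wrapper, cpp_wrapper)  # expected: PythonWrapperCodegen, CppWrapperGpu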

torch/_inductor/codegen/cpp_utils.py (+1)

@@ -81,6 +81,7 @@
 DEVICE_TO_ATEN = {
     "cpu": "at::kCPU",
     "cuda": "at::kCUDA",
+    "xpu": "at::kXPU",
 }

 LAYOUT_TO_ATEN = {
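The new table entry lets C++ codegen translate the "xpu" device string into its ATen device constant. A minimal sketch of that lookup follows; emit_device is an illustrative helper, not an Inductor API.

# Hedged sketch of how a device-to-ATen table like DEVICE_TO_ATEN is consulted
# when splicing device constants into generated C++.
DEVICE_TO_ATEN = {
    "cpu": "at::kCPU",
    "cuda": "at::kCUDA",
    "xpu": "at::kXPU",
}

def emit_device(device_str: str) -> str:
    # Map an Inductor device string onto the ATen token used in C++ output.
    return DEVICE_TO_ATEN[device_str]

print(emit_device("xpu"))  # at::kXPU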

torch/_inductor/codegen/cpp_wrapper_cpu.py (+8 -3)

@@ -198,11 +198,16 @@ class RAIIPyObject {
            }}
            """
        )
-        extend_aoti_path = (
+        extend_aoti_c_shim_include = (
             f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h"
         )
-        if os.path.exists(extend_aoti_path):
-            self.header.splice(f"#include <{extend_aoti_path}>")
+        extend_aoti_c_shim_path = os.path.join(
+            os.path.dirname(torch.__file__),
+            "include",
+            extend_aoti_c_shim_include,
+        )
+        if os.path.exists(extend_aoti_c_shim_path):
+            self.header.splice(f"#include <{extend_aoti_c_shim_include}>")

         enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
             "linux",

torch/_inductor/codegen/wrapper.py (+1)

@@ -782,6 +782,7 @@ def write_kernel_autotune_defs_header(self) -> None:
            async_compile = AsyncCompile()
            generate_example_value = AlgorithmSelectorCache.generate_example_value
            empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
+           empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
            """
        )
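The compile-time autotune header now also binds an XPU allocation helper next to the existing CUDA one. Below is a hedged probe for that binding; the attribute may be absent on builds without XPU support, which is why the sketch checks rather than calls it.

# Hedged sketch: check whether the XPU allocation helper referenced by the
# generated autotune header is available in this build of torch.
import torch

has_xpu_alloc = hasattr(torch._C._dynamo.guards, "_empty_strided_xpu")
print(f"_empty_strided_xpu binding available: {has_xpu_alloc}")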

torch/_inductor/codegen/xpu/device_op_overrides.py (+46 -1)

@@ -16,7 +16,7 @@ def device_guard(self, device_idx):
         return f"torch.xpu._DeviceGuard({device_idx})"

     def cpp_device_guard(self):
-        return "at::xpu::XPUGuard"
+        return "at::DeviceGuard"

     def cpp_aoti_device_guard(self):
         return "AOTIXpuGuard"
@@ -30,5 +30,50 @@ def cpp_aoti_stream_guard(self):
     def cpp_getStreamFromExternal(self):
         return "at::xpu::getStreamFromExternal"

+    def kernel_header(self):
+        source_codes = """
+        #include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
+        """
+        return source_codes
+
+    def kernel_driver(self):
+        source_codes = """
+            namespace {
+
+            struct Grid {
+                Grid(uint32_t x, uint32_t y, uint32_t z)
+                  : grid_x(x), grid_y(y), grid_z(z) {}
+                uint32_t grid_x;
+                uint32_t grid_y;
+                uint32_t grid_z;
+
+                bool is_non_zero() {
+                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
+                }
+            };
+
+            } // anonymous namespace
+
+        """
+        return source_codes
+
+    def abi_compatible_header(self):
+        return """
+        #include <torch/csrc/inductor/aoti_runtime/utils_xpu.h>
+        #include <torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h>
+        """
+
+    def cpp_stream_type(self):
+        return "sycl::queue*"
+
+    def aoti_get_stream(self):
+        return "aoti_torch_get_current_xpu_stream"
+
+    def cpp_kernel_type(self):
+        return "std::unique_ptr<sycl::kernel>"
+
+    def cpp_device_ptr(self):
+        return "void *"
+

 register_device_op_overrides("xpu", XPUDeviceOpOverrides())
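Taken together, these overrides tell the C++/AOTI wrapper which SYCL-flavored types and helpers to splice into generated code (a sycl::queue* stream, a std::unique_ptr<sycl::kernel> kernel handle, and so on). A hedged sketch of reading them back through the registry; it assumes get_device_op_overrides is exported from torch._inductor.codegen.common alongside register_device_op_overrides.

# Hedged sketch: query the strings the XPU backend splices into generated C++.
from torch._inductor.codegen.common import get_device_op_overrides

xpu_ops = get_device_op_overrides("xpu")
print(xpu_ops.cpp_device_guard())  # expected: "at::DeviceGuard"
print(xpu_ops.cpp_stream_type())   # expected: "sycl::queue*"
print(xpu_ops.cpp_kernel_type())   # expected: "std::unique_ptr<sycl::kernel>"
print(xpu_ops.aoti_get_stream())   # expected: "aoti_torch_get_current_xpu_stream"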
