Skip to content

Commit e51e971

Browse files
xuhancnpytorchmergebot
authored andcommitted
[inductor] adapte windows file path (pytorch#130713)
This PR is depends on pytorch#130132 can be landed successful. The detailed log: pytorch#124245 (comment) After the file path was adapted for Windows, the first Windows inductor case was run successful. ```python import torch def foo(x, y): a = torch.sin(x) b = torch.cos(x) return a + b opt_foo1 = torch.compile(foo) print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10))) ``` Result: ![image](https://github.com/user-attachments/assets/4944df47-e74d-476b-8eb5-1d1fd5abeb41) Co-authored-by: Jiong Gong <jiong.gong@intel.com> Pull Request resolved: pytorch#130713 Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/desertfire
1 parent 7c45476 commit e51e971

File tree

4 files changed

+26
-16
lines changed

4 files changed

+26
-16
lines changed

test/functorch/test_eager_transforms.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -5085,9 +5085,8 @@ def wrapper(*args, **kwargs):
50855085
@markDynamoStrictTest
50865086
class TestCompileTransforms(TestCase):
50875087
@skipIfRocm(msg="test leaks memory on ROCm")
5088-
# torch.compile is not supported on Windows
50895088
# Triton only supports GPU with SM70 or later.
5090-
@expectedFailureIf(IS_WINDOWS or (TEST_CUDA and not SM70OrLater))
5089+
@expectedFailureIf(TEST_CUDA and not SM70OrLater)
50915090
def test_compile_vmap_hessian(self, device):
50925091
# The model and inputs are a smaller version
50935092
# of code at benchmark repo:

torch/_inductor/codecache.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,26 @@
5757
rocm_compile_command,
5858
rocm_compiler,
5959
)
60-
from .cpp_builder import (
61-
_get_python_include_dirs,
62-
get_cpp_compiler,
63-
homebrew_libomp,
64-
is_apple_clang,
65-
is_clang,
66-
is_conda_llvm_openmp_installed,
67-
)
6860

6961
"""
7062
codecache.py, cpp_builder.py and cpu_vec_isa.py import rule:
7163
https://github.com/pytorch/pytorch/issues/124245#issuecomment-2197778902
7264
"""
7365
from torch._inductor.cpp_builder import (
66+
_get_python_include_dirs,
7467
_set_gpu_runtime_env,
7568
_transform_cuda_paths,
7669
CppBuilder,
7770
CppOptions,
7871
CppTorchCudaOptions,
7972
get_compiler_version_info,
73+
get_cpp_compiler,
8074
get_name_and_dir_from_output_file_path,
75+
homebrew_libomp,
76+
is_apple_clang,
77+
is_clang,
78+
is_conda_llvm_openmp_installed,
79+
normalize_path_separator,
8180
)
8281
from torch._inductor.cpu_vec_isa import invalid_vec_isa, pick_vec_isa, VecISA
8382
from torch._inductor.runtime.compile_tasks import (
@@ -1922,7 +1921,7 @@ def cpp_prefix_path() -> str:
19221921
content,
19231922
"h",
19241923
)
1925-
return filename
1924+
return normalize_path_separator(filename)
19261925

19271926

19281927
def cpp_prefix() -> str:
@@ -2109,7 +2108,7 @@ def load_async(cls, source_code: str, cuda=False, submit_fn=None, extra_flags=()
21092108
fb_output_path,
21102109
)
21112110

2112-
binary_path = (
2111+
binary_path = normalize_path_separator(
21132112
fb_output_path
21142113
if config.is_fbcode()
21152114
else cpp_builder.get_target_file_path()

torch/_inductor/cpp_builder.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,12 @@ def run_command_line(cmd_line, cwd=None):
244244
return status
245245

246246

247+
def normalize_path_separator(orig_path: str) -> str:
248+
if _IS_WINDOWS:
249+
return orig_path.replace(os.sep, "/")
250+
return orig_path
251+
252+
247253
class BuildOptionsBase:
248254
"""
249255
This is the Base class for store cxx build options, as a template.
@@ -1219,7 +1225,7 @@ def format_build_command(
12191225
f"{compiler} {include_dirs_args} {definations_args} {cflags_args} {sources} "
12201226
f"{passthougn_args} /LD /Fe{target_file} /link {libraries_dirs_args} {libraries_args} {ldflags_args} "
12211227
)
1222-
cmd = cmd.replace("\\", "/")
1228+
cmd = normalize_path_separator(cmd)
12231229
else:
12241230
compile_only_arg = "-c" if self._compile_only else ""
12251231
cmd = re.sub(
@@ -1247,7 +1253,7 @@ def format_build_command(
12471253
return command_line
12481254

12491255
def get_target_file_path(self):
1250-
return self._target_file
1256+
return normalize_path_separator(self._target_file)
12511257

12521258
def build(self) -> Tuple[int, str]:
12531259
"""

torch/_inductor/cpu_vec_isa.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ def __hash__(self) -> int:
8989

9090
def check_build(self, code: str) -> bool:
9191
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write
92-
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
92+
from torch._inductor.cpp_builder import (
93+
CppBuilder,
94+
CppTorchOptions,
95+
normalize_path_separator,
96+
)
9397

9498
key, input_path = write(
9599
code,
@@ -111,7 +115,9 @@ def check_build(self, code: str) -> bool:
111115
)
112116
try:
113117
# Check if the output file exist, and compile when not.
114-
output_path = x86_isa_help_builder.get_target_file_path()
118+
output_path = normalize_path_separator(
119+
x86_isa_help_builder.get_target_file_path()
120+
)
115121
if not os.path.isfile(output_path):
116122
status, target_file = x86_isa_help_builder.build()
117123

0 commit comments

Comments
 (0)