Commit 87f651c

frank-wei authored and pytorchmergebot committed
fix cpu test errors (pytorch#124116)
A similar fix was written by @int3 but not landed; credit to @int3 as well.
Pull Request resolved: pytorch#124116
Approved by: https://github.com/chenyang78
1 parent 2e48b39 commit 87f651c

2 files changed: +32 -7 lines changed


test/inductor/test_aot_inductor.py

+3 -3 lines changed

@@ -2714,11 +2714,11 @@ def fail_non_abi_compatible_cuda(is_skip=False):
     # FIXME: failed with Segfault while exiting the Python runtime
     "test_scatter_fallback": fail_stack_allocation(is_skip=True),
     # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978
-    "test_scatter_reduce_fallback": fail_stack_allocation(is_skip=True),
+    "test_scatter_reduce_fallback": fail_minimal_arrayref_interface(is_skip=True),
     # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978
-    "test_index_put_fallback": fail_stack_allocation(is_skip=True),
+    "test_index_put_fallback": fail_minimal_arrayref_interface(is_skip=True),
     # https://github.com/pytorch/pytorch/issues/122984
-    "test_index_put_with_none_index": fail_stack_allocation(is_skip=True),
+    "test_index_put_with_none_index": fail_minimal_arrayref_interface(is_skip=True),
     # FIXME: failed with Segfault while exiting the Python runtime
     "test_constant": fail_stack_allocation(is_skip=True),
     # C++ compile error, need for aoti_torch___scaled_dot_product_flash_attention_for_cpu
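
For orientation: the three updated entries move from the stack-allocation expected-failure marker to the minimal-arrayref-interface one, so the tests are skipped only in the narrower suite. Below is a minimal sketch of how such per-suite markers typically look in this test file, assuming a simple TestFailure record; the record definition and the suite names are illustrative assumptions, not copied from the repository.

    # Hedged sketch of the expected-failure marker pattern; suite names are invented.
    import collections

    TestFailure = collections.namedtuple("TestFailure", ["suites", "is_skip"])

    def fail_stack_allocation(is_skip=False):
        # Expected failure (or skip) for the stack-allocation CPU suites.
        return TestFailure(("cpu_with_stack_allocation",), is_skip=is_skip)

    def fail_minimal_arrayref_interface(is_skip=False):
        # Expected failure (or skip) only for the minimal-arrayref-interface suite.
        return TestFailure(
            ("cpu_with_stack_allocation_and_minimal_arrayref_interface",),
            is_skip=is_skip,
        )

With helpers like these, an entry such as "test_scatter_reduce_fallback": fail_minimal_arrayref_interface(is_skip=True) limits the skip to the minimal-arrayref-interface configuration instead of every stack-allocation suite.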

torch/_inductor/codegen/cpp_wrapper_cpu.py

+29 -4 lines changed

@@ -1277,12 +1277,19 @@ def generate_user_defined_triton_kernel(
     def generate_scatter_fallback(
         self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
     ):
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+
         # TODO: needs updates to use C shim v2
         # TODO: support other overload for cpp wrapper and remove the below assertions
         if config.abi_compatible:
             # call the ABI shim function instead of the ATen one
             kernel = kernel.replace("at::", "aoti_torch_")
-            line = f"{kernel}({output}, {','.join(map(str, inputs))}"
+            inputs_wrapped = [f"convert_arrayref_tensor_to_tensor({x})" for x in inputs]
+            line = f"{kernel}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
+        else:
+            line = f"{kernel}({output}, {','.join(map(str, inputs))}"
+
         if python_kernel_name == "aten.scatter_":
             if src_is_tensor:
                 if reduce:
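
The change disables stack allocation whenever this fallback is emitted (per the "# No stack allocation when there is a fallback op" comment) and, in ABI-compatible mode, wraps the output and inputs in convert_arrayref_tensor_to_tensor(...) so the shim call receives regular tensor handles rather than arrayref tensors (the rationale is an inference from the code, not stated in the commit). Here is a standalone sketch of the string assembly; the kernel and buffer names (at::scatter_out, buf0, buf1, buf2) are invented placeholders, not values from the commit.

    # Hedged sketch of the f-string assembly performed by generate_scatter_fallback.
    kernel = "at::scatter_out"      # assumed example kernel name
    kernel = kernel.replace("at::", "aoti_torch_")
    output = "buf0"                 # placeholder C++ variable names
    inputs = ["buf1", "buf2"]

    inputs_wrapped = [f"convert_arrayref_tensor_to_tensor({x})" for x in inputs]
    line = f"{kernel}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
    print(line)
    # aoti_torch_scatter_out(convert_arrayref_tensor_to_tensor(buf0), convert_arrayref_tensor_to_tensor(buf1),convert_arrayref_tensor_to_tensor(buf2)

The call string is deliberately left without a closing parenthesis at this point; the remainder of the method appends overload-specific arguments and closes the call, as suggested by the if python_kernel_name == "aten.scatter_" block that follows.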
@@ -1297,22 +1304,40 @@ def generate_scatter_fallback(
         self.writeline(line)
 
     def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+
         # TODO: needs updates to use C shim v2
         if config.abi_compatible:
             # See the comment in codegen_reinterpret_view about why having something like
             # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding
             # tensor prematurely deallocated, thus this std::vector().data() trick here.
             indices_str = (
-                f"std::vector<AtenTensorHandle>{{{', '.join(indices)}}}.data()"
+                "std::vector<AtenTensorHandle>{"
+                + (
+                    ", ".join(
+                        [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
+                    )
+                )
+                + "}.data()"
             )
-            args = [x, indices_str, str(len(indices)), values, accumulate]
+            args = [
+                f"convert_arrayref_tensor_to_tensor({x})",
+                indices_str,
+                str(len(indices)),
+                f"convert_arrayref_tensor_to_tensor({values})",
+                accumulate,
+            ]
+            args.insert(
+                0, f"convert_arrayref_tensor_to_tensor({x})"
+            )  # set x as the output tensor, this fallback mutates x.
         else:
             indices_str = (
                 f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
             )
             args = [x, indices_str, values, accumulate]
+            args.insert(0, x)  # set x as the output tensor, this fallback mutates
 
-        args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
         self.writeline(self.wrap_kernel_call(kernel, args))
 
     def add_benchmark_harness(self, output):
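
As with the scatter fallback, stack allocation is turned off and every tensor argument is routed through convert_arrayref_tensor_to_tensor. A standalone sketch of the resulting indices_str assembly follows, with invented buffer names idx0 and idx1.

    # Hedged sketch of the indices_str assembly in generate_index_put_fallback.
    indices = ["idx0", "idx1"]   # placeholder C++ variable names

    indices_str = (
        "std::vector<AtenTensorHandle>{"
        + ", ".join(f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices)
        + "}.data()"
    )
    print(indices_str)
    # std::vector<AtenTensorHandle>{convert_arrayref_tensor_to_tensor(idx0), convert_arrayref_tensor_to_tensor(idx1)}.data()

The temporary std::vector keeps the converted handles alive for the duration of the generated call (see the RAIIAtenTensorHandle comment above), and x is inserted at position 0 because the index_put fallback mutates it in place.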
