@@ -1277,12 +1277,19 @@ def generate_user_defined_triton_kernel(
     def generate_scatter_fallback(
         self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
     ):
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+
         # TODO: needs updates to use C shim v2
         # TODO: support other overload for cpp wrapper and remove the below assertions
         if config.abi_compatible:
             # call the ABI shim function instead of the ATen one
             kernel = kernel.replace("at::", "aoti_torch_")
-        line = f"{kernel}({output}, {','.join(map(str, inputs))}"
+            inputs_wrapped = [f"convert_arrayref_tensor_to_tensor({x})" for x in inputs]
+            line = f"{kernel}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
+        else:
+            line = f"{kernel}({output}, {','.join(map(str, inputs))}"
+
         if python_kernel_name == "aten.scatter_":
             if src_is_tensor:
                 if reduce:
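For readers skimming the hunk above: the ABI-compatible branch now builds the scatter fallback call string with every tensor argument wrapped in `convert_arrayref_tensor_to_tensor(...)`. A minimal sketch of that string construction follows, using hypothetical kernel and buffer names (not taken from this diff); the call string is intentionally left unclosed because the rest of the method appends the remaining arguments and the closing parenthesis.

```python
# Minimal sketch of the string built in the abi_compatible branch above.
# "aoti_torch_scatter_out", "buf0", "buf1", "buf2" are hypothetical names.
kernel = "aoti_torch_scatter_out"  # after the at:: -> aoti_torch_ replacement
output = "buf0"
inputs = ["buf1", "buf2"]

inputs_wrapped = [f"convert_arrayref_tensor_to_tensor({x})" for x in inputs]
line = f"{kernel}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}"
print(line)
# aoti_torch_scatter_out(convert_arrayref_tensor_to_tensor(buf0), convert_arrayref_tensor_to_tensor(buf1),convert_arrayref_tensor_to_tensor(buf2)
```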
@@ -1297,22 +1304,40 @@ def generate_scatter_fallback(
         self.writeline(line)
 
     def generate_index_put_fallback(self, kernel, x, indices, values, accumulate):
+        # No stack allocation when there is a fallback op
+        self.allow_stack_allocation = False
+
         # TODO: needs updates to use C shim v2
         if config.abi_compatible:
             # See the comment in codegen_reinterpret_view about why having something like
             # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding
             # tensor prematurely deallocated, thus this std::vector().data() trick here.
             indices_str = (
-                f"std::vector<AtenTensorHandle>{{{', '.join(indices)}}}.data()"
+                "std::vector<AtenTensorHandle>{"
+                + (
+                    ", ".join(
+                        [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices]
+                    )
+                )
+                + "}.data()"
             )
-            args = [x, indices_str, str(len(indices)), values, accumulate]
+            args = [
+                f"convert_arrayref_tensor_to_tensor({x})",
+                indices_str,
+                str(len(indices)),
+                f"convert_arrayref_tensor_to_tensor({values})",
+                accumulate,
+            ]
+            args.insert(
+                0, f"convert_arrayref_tensor_to_tensor({x})"
+            )  # set x as the output tensor, this fallback mutates x.
         else:
             indices_str = (
                 f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}"
             )
             args = [x, indices_str, values, accumulate]
+            args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
 
-        args.insert(0, x)  # set x as the output tensor, this fallback mutates x.
         self.writeline(self.wrap_kernel_call(kernel, args))
 
     def add_benchmark_harness(self, output):
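Similarly, the index_put fallback now wraps each index handle before packing it into the temporary `std::vector<AtenTensorHandle>` (the `.data()` trick that, per the comment in the diff, avoids prematurely deallocating the handles). A minimal sketch of the `indices_str` built in the ABI-compatible branch, with hypothetical buffer names:

```python
# Minimal sketch of indices_str from the abi_compatible branch above.
# "buf3" and "buf4" are hypothetical index buffer names.
indices = ["buf3", "buf4"]

indices_str = (
    "std::vector<AtenTensorHandle>{"
    + ", ".join(f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices)
    + "}.data()"
)
print(indices_str)
# std::vector<AtenTensorHandle>{convert_arrayref_tensor_to_tensor(buf3), convert_arrayref_tensor_to_tensor(buf4)}.data()
```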