Skip to content

Commit

Permalink
[ROCm] Fix build break in xla/service/gpu/ir_emitter_triton_rocm.cc
Browse files: browse the repository at this point in the history
  • Loading branch information
zoranjovanovic-ns authored and Ruturaj4 committed May 30, 2024
1 parent ebf4374 commit ea1b8a6
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions xla/service/gpu/ir_emitter_triton_rocm.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -71,23 +71,23 @@ absl::Status CreateTritonPipeline(
pm.addPass(mt::createConvertTritonToTritonGPUPass(
absl::StrFormat("cuda:%u", ccAsInt), config.num_warps, threadsPerWarp,
config.num_ctas));
pm.addPass(mt::gpu::createCoalescePass());
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createOptimizeThreadLocalityPass());
pm.addPass(mt::gpu::createAccelerateMatmulPass(ccAsInt));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createTritonGPUCoalesce());
pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul({ccAsInt}));
pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
// TODO ROCm Check if we want to compare MI100 and greater
pm.addPass(mt::gpu::createOptimizeDotOperandsPass(true));
pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
pm.addPass(mlir::createCSEPass());
pm.addPass(mt::gpu::createTritonGPUPipeline(
{config.num_stages, config.num_warps, config.num_ctas, ccAsInt}));
pm.addPass(mt::gpu::createTritonGPUPipeline({config.num_stages, config.num_warps,
config.num_ctas, ccAsInt}));
pm.addPass(mt::gpu::createTritonGPUPrefetch());

// TODO ROCm Check if we want to compare MI100 and greater
pm.addPass(mt::gpu::createOptimizeDotOperandsPass(true));
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createReduceDataDuplicationPass());
pm.addPass(mt::gpu::createReorderInstructionsPass());
pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());
pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::createCanonicalizerPass());
Expand Down

0 comments on commit ea1b8a6

Please sign in to comment.