From f6f77a158b8a3e523472118142afe5e560c370e1 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Fri, 21 Feb 2025 19:20:54 +0100 Subject: [PATCH] OpenXLA-specific changes --- BUILD | 928 +++++++++++++++++ .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 6 +- .../TritonToTritonGPU/TritonGPUConversion.cpp | 12 + lib/Dialect/TritonGPU/IR/Dialect.cpp | 6 + lib/Dialect/TritonGPU/IR/Ops.cpp | 9 +- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 47 +- .../Pipeliner/MatmulLoopPipeline.cpp | 9 +- .../Pipeliner/PipeliningUtility.cpp | 7 +- lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 26 +- .../Transforms/RemoveLayoutConversions.cpp | 37 +- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 30 +- .../Transforms/FenceInsertion.cpp | 17 +- python/BUILD | 79 ++ python/src/ir.cc | 7 +- python/src/passes.cc | 2 +- python/test/regression/BUILD | 26 + python/test/unit/BUILD | 197 ++++ python/test/unit/language/test_core.py | 19 + python/test/unit/runtime/test_peer_access.py | 24 + python/triton/_C/include | 2 +- python/triton/backends/__init__.py | 7 +- python/triton/language/core.py | 8 +- python/triton/runtime/build.py | 37 - test/BUILD | 76 ++ test/Conversion/amd/async_ops_to_llvm.mlir | 10 +- test/TritonGPU/accelerate-matmul.mlir | 18 + .../amd/accelerate-amd-matmul-mfma.mlir | 4 +- .../TritonGPU/amd/amd-convert-buffer-ops.mlir | 2 +- test/TritonGPU/canonicalize.mlir | 16 + test/TritonGPU/combine.mlir | 11 +- test/TritonGPU/prefetch.mlir | 20 + .../samples/simulated-grouped-gemm.mlir | 194 ++-- third_party/amd/BUILD | 268 +++++ .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 2 +- .../lib/TritonAMDGPUToLLVM/TargetUtils.cpp | 2 - third_party/f2reduce/BUILD | 31 + third_party/nvidia/BUILD | 319 ++++++ third_party/nvidia/backend/BUILD | 30 + third_party/nvidia/backend/cuda_utils.cc | 929 ++++++++++++++++++ third_party/nvidia/backend/driver.c | 421 -------- third_party/nvidia/backend/driver.py | 494 ++-------- .../include/Dialect/NVGPU/IR/NVGPUOps.td | 9 + third_party/nvidia/language/cuda/BUILD | 13 + .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 34 +- .../DotOpToLLVM/MMAv5.cpp | 13 +- .../DotOpToLLVM/WGMMA.cpp | 2 +- .../ElementwiseOpToLLVM.cpp | 111 ++- third_party/proton/BUILD | 130 +++ third_party/proton/proton/_C/include | 2 +- unittest/BUILD | 144 +++ 50 files changed, 3804 insertions(+), 1043 deletions(-) create mode 100644 BUILD create mode 100644 python/BUILD create mode 100644 python/test/regression/BUILD create mode 100644 python/test/unit/BUILD create mode 100644 python/test/unit/runtime/test_peer_access.py delete mode 100644 python/triton/runtime/build.py create mode 100644 test/BUILD create mode 100644 third_party/amd/BUILD create mode 100644 third_party/f2reduce/BUILD create mode 100644 third_party/nvidia/BUILD create mode 100644 third_party/nvidia/backend/BUILD create mode 100644 third_party/nvidia/backend/cuda_utils.cc delete mode 100644 third_party/nvidia/backend/driver.c create mode 100644 third_party/nvidia/language/cuda/BUILD create mode 100644 third_party/proton/BUILD create mode 100644 unittest/BUILD diff --git a/BUILD b/BUILD new file mode 100644 index 000000000000..e662397f1f99 --- /dev/null +++ b/BUILD @@ -0,0 +1,928 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. 
+# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # # Add your project here if you need to depend on Triton's C++ sources. + # # Add a point of contact we can reach out to when needed in the comment. + # # + # # If you need to use the Python fronted, add your project to + # # google3/third_party/py/triton/BUILD instead. + # # + # # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla + # "//third_party/py/jax:__subpackages__", # cjfj@ + # "//third_party/tensorflow/compiler/xla:__subpackages__", # bchetioui@ + # "//platforms/xla/experimental/gpu:__subpackages__", # csigg@ + # # Triton-internal visibility + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", 
+ td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. + strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_canonicalize_inc_gen", + # The generated file is #included without relative path. 
+ strip_include_prefix = "lib/Dialect/Triton/IR", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/IR/TritonCanonicalize.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/IR/Canonicalize.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/AttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = 
"triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_op_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-op-interface-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.h.inc", + ), + ( + ["--gen-op-interface-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_op_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-op-interface-decls"], + "include/triton/Dialect/Triton/IR/OpInterfaces.h.inc", + ), + ( + ["--gen-op-interface-defs"], + "include/triton/Dialect/Triton/IR/OpInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOpInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + 
"lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + "lib/Tools/*.cpp", + # There are so many interdependencies between Dialect and Analysis that we're just compiling + # everything in a single unit. + "lib/Analysis/*.cpp", + ]) + [ + "include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h", # Avoid circular dependency. + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", # Avoid circular dependency. + "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + "include/triton/Tools/*.h", + # There are so many interdependencies between Dialect and Analysis that we're just compiling + # everything in a single unit. + "include/triton/Analysis/*.h", + ]) + [ + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. + # What is this lone header doing rooted under Conversion? Best to add it to Dialect, but + # it would be better if upstream moved it there. + "include/triton/Conversion/MLIRTypes.h", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_canonicalize_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_gpu_type_interfaces_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_op_interfaces_inc_gen", + ":triton_op_interfaces_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:UBDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@triton//third_party/nvidia:NVGPUDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/f2reduce", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). 
+) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h", + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-logical-op-parentheses", + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + "-Wno-string-conversion", + ], + }), + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]) + [ + "@triton//test:lib/Dialect/TritonGPU/TestTC05MMAPipeline.cpp", + ], + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":TritonTools", + ":triton_gpu_transforms_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + 
"@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "@triton//third_party/proton:ProtonIRDialect", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], +) + +cc_library( + name = "AllPassesAndDialects", + srcs = [ + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + hdrs = ["bin/RegisterTritonDialects.h"], + includes = ["."], # because it includes 
third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//mlir:AllPassesAndDialects", + "@triton//test:TritonTestAnalysis", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + "@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + "@triton//third_party/proton:ProtonIRDialect", + ], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/triton-opt.cpp", + ], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = ["bin/triton-reduce.cpp"], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + ], +) + +cc_binary( + name = "triton-tensor-layout", + srcs = ["bin/triton-tensor-layout.cpp"], + deps = [ + ":AllPassesAndDialects", + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +filegroup( + name = "metadata-file", + srcs = ["METADATA"], +) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 421c1e89481f..e8f512c83b90 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -502,15 +502,17 @@ We call each individual tile "rep". "unsigned", "getTotalElemsPerThread", (ins "ArrayRef":$shape), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ - return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape); + return toLinearEncoding($_attr, shape).getTotalElemsPerThread(shape); }]>, InterfaceMethod<"Return element size per thread in each dimension.", "SmallVector", "getElemsPerThread", (ins "ArrayRef":$shape), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ - return toLinearEncoding($_self, shape).getElemsPerThread(shape); + return toLinearEncoding($_attr, shape).getElemsPerThread(shape); }]>, // Interface for the meta information about the multiple thread hierarchy. 
InterfaceMethod<"Get the shape of the warps per CTA.", diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 773c01e4a2a0..93e4fb7fa91e 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -58,6 +58,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining arguments that have been converted to a new type. + // We use this to rewrite triton_xla.sparse_dot in a separate pass after + // 'convert-triton-to-tritongpu'. + return builder.create(loc, tensorType, + inputs); llvm_unreachable("Argument rematerialization should not happen in Triton " "-> TritonGPU conversion"); return {}; @@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // convert origValue to newValue addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining uses of values that have been converted to a new type. + // We use this to rewrite triton_xla.sparse_dot in a separate pass after + // 'convert-triton-to-tritongpu'. + return builder.create(loc, tensorType, + inputs); llvm_unreachable("Source rematerialization should not happen in Triton -> " "TritonGPU Conversion"); return {}; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index dbf008c1e6e8..7f130ce9a56e 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -39,11 +39,17 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef shape) { } unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape) { + if (auto distLayout = mlir::dyn_cast(layout)) { + return distLayout.getTotalElemsPerThread(shape); + } return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape); } SmallVector getElemsPerThread(Attribute layout, ArrayRef shape) { + if (auto distLayout = mlir::dyn_cast(layout)) { + return distLayout.getElemsPerThread(shape); + } return toLinearEncoding(layout, shape).getElemsPerThread(shape); } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index c3d8ff49407f..78076474f4cc 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -159,6 +159,11 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding, so we want to keep this layout conversion. + if (mlir::isa( + convert.getSrc().getType().getEncoding())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); return mlir::success(); @@ -221,8 +226,8 @@ struct CanonicalizeConvertFromConvert // heuristic to accommodate fused attention. 
auto srcType = op.getSrc().getType(); auto dstType = op.getType(); - if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) && - mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding())) + if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) && + mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding())) return failure(); Operation *arg = op.getSrc().getDefiningOp(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index c0396621bdea..f49c5cfd0511 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -21,8 +21,6 @@ namespace mlir { namespace triton { namespace gpu { -namespace { - // Get the highest version supported for the hardware and the dot. static int getMMAVersionSafe(int computeCapability, DotOp op) { // List supported mma version in order of preference. @@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { return 0; } -SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape, - int numWarps) { +SmallVector<unsigned> +warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) { auto rank = shape.size(); // Early exit for batched matmul if (rank == 3) @@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape, } SmallVector<unsigned> -warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps, +warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps, const SmallVector<unsigned> &instrShape) { SetVector<Operation *> slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); + mlir::getForwardSlice(dotOp->getResult(0), &slices); // Contains a chained dot. We prefer to assign warps to one axis // to facilitate use cases like flash attention, allowing reductions within // the same warp. @@ -181,6 +179,21 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx, auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); + + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding. + if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>( + argType.getEncoding())) { + // Create a layout conversion from DotOperandEncoding to BlockedEncoding + // then pass it to the LocalAllocOp. + auto newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), dotOpEnc.getParent()); + auto dotOperandToBlockedCvt = + rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg); + return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, + dotOperandToBlockedCvt); + } + return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg); } @@ -204,7 +217,7 @@ getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) { } SmallVector<unsigned> -getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version, +getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version, int numWarps, const SmallVector<unsigned> &instrShape) { switch (version) { case 2: @@ -218,6 +231,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version, } static bool bwdFilter(Operation *op) { + // Dot operand layout assignment to predicates is not currently supported + // during lowering from TritonGPU to LLVM in Triton for MMA cases. This + // condition limits visibility of the original bit-width so that predicates + // are not considered; hence, kwidth can never be 32.
+ if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + return false; + } + return op->getNumOperands() == 1 && (isa(op) || isPureUnaryInlineAsm(op) || @@ -237,7 +260,7 @@ static bool bwdFilter(Operation *op) { // result, kwidth can be the bitwidth of the lower precision primitive. // Conversely, in the downcasting scenario, no reordering is performed, // making it directory use the lower precision primitive. -static int computeOrigBitWidth(Value x) { +int computeOrigBitWidth(Value x) { int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); int origBitWidth = finalBitWidth; SetVector slice; @@ -257,6 +280,9 @@ static int computeOrigBitWidth(Value x) { } return origBitWidth; } +// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity +// extension. +namespace { class BlockedToMMA : public mlir::OpRewritePattern { int computeCapability; @@ -1147,6 +1173,11 @@ class TritonGPUAccelerateMatmulPass } }; +Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter, + int opIdx, bool allowTranspose) { + return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose); +} + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 9962112e9389..9224cc45325a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -131,6 +131,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc, Value zero = builder.createWithStage( forOp.getLoc(), stage, clusterId, 0, 32); + // Replace the load with insert/extract slice. builder.setInsertionPoint(loadOp); Location loc = loadOp.getLoc(); @@ -524,7 +525,8 @@ assignMemoryLayouts(scf::ForOp &forOp, bool isTMALoad = isa(op); - loadsToPipeline.insert(&op); + // TODO: b/381421713 - Uncomment this once pipelining is fixed. + // loadsToPipeline.insert(&op); LoadInfo loadInfo; for (auto use : users) { if (isa(use)) { @@ -562,6 +564,11 @@ assignMemoryLayouts(scf::ForOp &forOp, getBlockedEncoding(loadOp, axisInfoAnalysis); } } + + // TODO: b/381421713 - Remove this once pipelining is fixed. + if (!loadInfo.sharedEncoding) continue; + loadsToPipeline.insert(&op); + loadToInfo[&op] = loadInfo; } // Make sure all loads in loadsToPipeline are in loadToInfo. 
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 6df1c31f3855..04d340153630 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -255,7 +255,12 @@ mlir::triton::maybeGetStageCluster(Operation *op) { } std::pair<int, int> mlir::triton::getStageCluster(Operation *op) { auto res = maybeGetStageCluster(op); - assert(res.has_value() || "Operation is missing stage & cluster attribute"); + if (!res.has_value()) { // DO NOT SUBMIT + llvm::errs() << "op without stage & cluster:\n"; + op->dump(); + op->getParentOfType<ModuleOp>().dump(); + } + assert(res.has_value() && "Operation is missing stage & cluster attribute"); return *res; } diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 31dad426e715..0fa2a42864bb 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -121,7 +121,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, // opIdx: 0 => a, 1 => b auto type = cast<triton::gpu::MemDescType>(v.getType()); SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()}; - SmallVector<int64_t> offset{0, 0}; + SmallVector<int64_t> offset(shape.size(), 0); Type elementType = type.getElementType(); // k => (prefetchWidth, k - prefetchWidth) @@ -146,8 +146,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, type.getMutableMemory(), type.getAllocShape()), v, offsetsVal); + // We need to assign kwidth to zero in the case where the parent layout is + // Blocked, otherwise the verifier emits a failure. The parent layout is + // Blocked only when Tensor Cores are disabled. + int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding) + ? 0 + : prefetchWidth / 8; auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( - builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + builder.getContext(), opIdx, dotEncoding, kwidth); Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); @@ -197,6 +203,22 @@ LogicalResult Prefetcher::initialize() { break; if (!op->getResult(0).hasOneUse()) break; + // Similar to issues faced in HoistLayoutConversion pattern in + // OptimizeDotOperands.cpp, we can't propagate through type casts from + // predicates as they aren't supported in Triton when encoded with dot_op + // layout. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + break; + } + // Propagation through ExpandDims is currently not supported. This blindly + // replaces the encoding with dot encoding, but ExpandDims requires a + // SliceEncoding. This could be rewritten to support it somehow, but I + // don't think it's trivial & it's currently crashing.
+ if (isa<triton::ExpandDimsOp>(op)) { + break; + } rets.push_back(op->getOperand(0)); if (auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) { // NYI for other encodings, for example if we have transpose diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 99b60c07bdbe..4cffc3d217ec 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -165,6 +165,7 @@ class LayoutRematerialization { SetVector<Operation *> opToDelete; FuncOp funcOp; DominanceInfo domInfo; + PostDominanceInfo postDomInfo; }; void LayoutRematerialization::addRematValue(Value old, Attribute encoding, @@ -1120,12 +1121,40 @@ void LayoutRematerialization::hoistConvertDotOperand( ConvertLayoutOp convertOp) { auto targetType = convertOp.getType(); // The pass is targeted to Nvidia mma/wgmma dot operands + + // Partial cherry-pick of https://github.com/triton-lang/triton/pull/5475. + // Path 2 in b/391692127#comment28. Added check for parent being a for loop. + auto canBePipelined = [&](ConvertLayoutOp convertOp) { + auto parent = dyn_cast<scf::ForOp>(convertOp->getParentOp()); + if (!parent) + return false; + + // Find all the dot-like ops in the for loop that have a nvidia dot operand + // encoding on the lhs and check if any of them post-dominates the load + + // cvt + SmallVector<Operation *> dotLikeOps; + parent->walk([&](Operation *op) { + if (!isa(op)) + return; + auto opType = dyn_cast<RankedTensorType>(op->getOperand(0).getType()); + if (!opType) + return; + auto dotEnc = dyn_cast<DotOperandEncodingAttr>(opType.getEncoding()); + if (!dotEnc) + return; + if (isa<NvidiaMmaEncodingAttr>(dotEnc.getParent())) + dotLikeOps.push_back(op); + }); + if (dotLikeOps.empty()) + return false; + return llvm::any_of(dotLikeOps, [&](Operation *dot) { + return postDomInfo.postDominates(dot, convertOp); + }); + }; + // We move convert #dot_operand next to their loads.
This is done // so that it's then easy to pipeline these loads - // TODO: Perhaps we should do this whenever convertOp is within a loop - - auto dotEnc = dyn_cast(targetType.getEncoding()); - if (!(dotEnc && isa(dotEnc.getParent()))) + if (!canBePipelined(convertOp)) return; // We hoist over any operation that can be done without data movement between diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 4239131701fd..0eedfc74feec 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1022,18 +1022,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) return std::nullopt; - auto dotOpEnc = dyn_cast( - cast(user->getResult(0).getType()) - .getEncoding()); - if (!dotOpEnc) + auto enc = + cast(user->getResult(0).getType()).getEncoding(); + if (isa(enc)) { + auto srcTy = cast(val.getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + val.getContext(), cast(enc), + srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false); + } else if (enc.getAbstractAttribute().getName().str() == + "triton.gpu.sparse_dot_meta_encoding") { + auto srcTy = cast(val.getType()); + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1, + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding())); + } else { return std::nullopt; - auto srcTy = cast(val.getType()); - auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); - auto order = ttg::getOrder(srcTy.getEncoding()); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - tempAttr = ttg::SwizzledSharedEncodingAttr::get( - val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, - bitWidth, /*needTrans=*/false); + } } // Check that the shared encodings needed by the users are compatible. 
if (attr != nullptr && attr != tempAttr) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fc34ddda76b1..c8cbe69742d9 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -41,8 +41,10 @@ struct FenceInsertionPass if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) return; ModuleOp mod = getOperation(); + DenseSet> trace; mod.walk([&](Operation *op) { - if (!isa(op)) + if (!isa(op) && + op->getName().getStringRef() != "triton_xla.sparse_dot") return WalkResult::advance(); OpBuilder builder(op); auto a = op->getOperand(0); @@ -51,8 +53,8 @@ struct FenceInsertionPass cast(op->getResult(0).getType()).getEncoding()); if (!mmaEncoding || !mmaEncoding.isHopper()) return WalkResult::advance(); - bool aDependsOnShared = dependOnSharedEncOperand(a); - bool bDependsOnShared = dependOnSharedEncOperand(b); + bool aDependsOnShared = dependOnSharedEncOperand(a, trace); + bool bDependsOnShared = dependOnSharedEncOperand(b, trace); if (!aDependsOnShared && !bDependsOnShared) return WalkResult::advance(); Operation *fence = builder.create( @@ -73,8 +75,7 @@ struct FenceInsertionPass } private: - bool dependOnSharedEncOperand(Value operand) { - static DenseSet> trace; + bool dependOnSharedEncOperand(Value operand, DenseSet> &trace) { auto op = operand.getDefiningOp(); // avoid redundant insertion if (op && isa(op)) @@ -89,7 +90,7 @@ struct FenceInsertionPass // op and not BlockArgument if (op && !isa(operand)) { for (auto v : op->getOperands()) { - if (dependOnSharedEncOperand(v)) + if (dependOnSharedEncOperand(v, trace)) return true; } } @@ -104,7 +105,7 @@ struct FenceInsertionPass auto iterOperands = forOp.getInitArgs(); if (argNum == 0) return false; - if (dependOnSharedEncOperand(iterOperands[argNum - 1])) + if (dependOnSharedEncOperand(iterOperands[argNum - 1], trace)) return true; // yield auto yieldOp = forOp.getBody()->getTerminator(); @@ -117,7 +118,7 @@ struct FenceInsertionPass else trace.insert(entry); - if (dependOnSharedEncOperand(v)) + if (dependOnSharedEncOperand(v, trace)) return true; } else if (auto whileOp = dyn_cast(argOwner)) { assert(false && "FenceInsertionPass does not supported WhileOp"); diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 000000000000..247b8cda2103 --- /dev/null +++ b/python/BUILD @@ -0,0 +1,79 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. 
+ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "@triton//python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["@triton//third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Instrumentation", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "@triton//third_party/nvidia:triton_nvidia", + "@triton//third_party/proton:ProtonIRDialect", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/src/ir.cc b/python/src/ir.cc index 14fec22e5889..e873ecad5dbc 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -1865,6 +1865,11 @@ void init_triton_ir(py::module &&m) { if (showStacktraces) { context->disableMultithreading(); } + // DO NOT SUBMIT + llvm::errs() << "showRemarks: " << showRemarks << "\n"; + llvm::errs() << "showWarnings: " << showWarnings << "\n"; + llvm::errs() << "showStacktraces: " << showStacktraces << "\n"; + llvm::errs() << "showOperations: " << showOperations << "\n\n"; if (failed(self.run(mod.getOperation()))) throw std::runtime_error("PassManager::run failed"); }); diff --git a/python/src/passes.cc b/python/src/passes.cc index 619ece2e3455..b3ed20b8d5a0 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -80,7 +80,7 @@ void init_triton_passes_ttgpuir(py::module &&m) { void init_triton_passes_convert(py::module &&m) { using namespace mlir; - ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); + ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass); ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 000000000000..a88f4eeae1f8 --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,26 @@ 
+load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["test_*.py"], + exclude = [ + "test_performance.py", #TODO(b/321005767): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD new file mode 100644 index 000000000000..89f832050e5a --- /dev/null +++ b/python/test/unit/BUILD @@ -0,0 +1,197 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test") + +package( + default_applicable_licenses = ["//:license"], +) + +_requires_gpu_sm80 = [ + "config-cuda-only", + "requires-gpu-sm80", +] + +_requires_config_cuda = select( + {"@local_config_cuda//cuda:using_blaze_config_cuda": []}, + no_match_error = "Requires --config=cuda", +) + +EXCLUDE_TESTS = [ + "language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers + "language/test_subprocess.py", # TODO(b/320224484): fix failing test + "runtime/test_launch.py", # TODO(b/320226169): fix failing tests + "tools/test_aot.py", # TODO(b/320224484): fix failing test + "tools/test_disasm.py", # TODO(b/320224484): fix failing test + "runtime/test_cublas.py", # TODO(b/346755023): fix failing test + "test_debug.py", # TODO(b/374733875): fix failing test. Also see b/374733872. + "language/test_compile_only.py", # TODO(b/394497996): enable test, when CUDA version in g3 supports Blackwell + "test_perf_warning.py", # No backtraces in non-debug builds. +] + +# Runs all python tests on H100 +pytest_multi_tests( + name = "hopper", + size = "large", + srcs = [ + "conftest.py", + "language/test_core.py", + "language/test_mxfp.py", + ], + name_suffix = "_h100", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["**/test_*.py"], + exclude = EXCLUDE_TESTS + [ + "language/test_core.py", + "language/test_pipeliner.py", # TODO(b/362458006): fix failing test + "cuda/test_experimental_tma.py", # TODO(b/362458006): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "cuda/language/test_core_h100", + size = "large", + srcs = [ + "conftest.py", + ], + shard_count = 40, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/test_core.py", + "language/test_mxfp.py", + ], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["language/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["language/test_core.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. 
+pytest_test( + name = "language/test_core", + size = "large", + srcs = [ + "conftest.py", + ], + shard_count = 40, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "instrumentation", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["instrumentation/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + srcs = ["conftest.py"], + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["runtime/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["runtime/test_peer_access.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Requires 2 GPUs to run +pytest_test( + name = "runtime/test_peer_access", + size = "large", + srcs = ["conftest.py"], + tags = [ + "config-cuda-only", + "requires-gpu-sm90:2", + ], + target_compatible_with = _requires_config_cuda, + tests = ["runtime/test_peer_access.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["tools/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "unit", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b38f88eef62e..ebef420b5169 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4242,6 +4242,25 @@ def _kernel(out): kernel[(1, )](out) assert torch.all(out == out_ref) +@pytest.mark.interpreter +def test_dot_on_broadcast(device): + @triton.jit + def _kernel(a, b, out): + a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64) + rhs = tl.load(b) + rhs_bc = tl.broadcast_to(rhs, [32, 32]) + c = tl.dot(lhs, rhs_bc) + out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device) + b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device) + out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32))) + out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device) + _kernel[(1, )](a, b, out, num_stages=1, num_warps=4) + assert torch.all(out == out_ref) + # --------------- # test arange diff --git a/python/test/unit/runtime/test_peer_access.py b/python/test/unit/runtime/test_peer_access.py new file mode 100644 index 000000000000..873f61b88fa3 --- /dev/null +++ b/python/test/unit/runtime/test_peer_access.py @@ -0,0 +1,24 @@ +import pytest +import torch + +import triton +import triton.language as tl + + +def 
test_peer_access(device): + if not hasattr(torch, device): + pytest.skip(f"{device} does not support peer access") + if getattr(torch, device).device_count() < 2: + pytest.skip("need at least 2 devices to test peer access") + + @triton.jit + def device_accumulate(my_ptr, peer_ptr): + tl.store(my_ptr, tl.load(my_ptr) + tl.load(peer_ptr)) + + my_tensor = torch.randn(1, device=f"{device}:0") + peer_tensor = torch.randn(1, device=f"{device}:1") + expected = my_tensor + peer_tensor.to(device=f"{device}:0") + + device_accumulate[(1,1,1)](my_tensor, peer_tensor) + + torch.testing.assert_close(my_tensor, expected) diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1..8a5dba6c4b56 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file +../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..f9bab523bf6c 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 869ff9ba5997..6128d81d4819 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -794,7 +794,7 @@ def __str__(self) -> str: @builtin def __add__(self, other, _builder=None): - return add(self, other, sanitize_overflow=True, _builder=_builder) + return add(self, other, sanitize_overflow=False, _builder=_builder) @builtin def __radd__(self, other, _builder=None): @@ -810,7 +810,7 @@ def __rsub__(self, other, _builder=None): @builtin def __mul__(self, other, _builder=None): - return mul(self, other, sanitize_overflow=True, _builder=_builder) + return mul(self, other, sanitize_overflow=False, _builder=_builder) @builtin def __rmul__(self, other, _builder=None): @@ -2177,7 +2177,7 @@ def where(condition, x, y, _builder=None): @builtin -def add(x, y, sanitize_overflow: constexpr = True, _builder=None): +def add(x, y, sanitize_overflow: constexpr = False, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) return semantic.add(x, y, sanitize_overflow, _builder) @@ -2191,7 +2191,7 @@ def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): @builtin -def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): +def mul(x, y, sanitize_overflow: constexpr = False, _builder=None): x = _unwrap_if_constexpr(x) y = _unwrap_if_constexpr(y) return semantic.mul(x, y, sanitize_overflow, _builder) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py deleted file mode 100644 index 1b76548d43a7..000000000000 --- a/python/triton/runtime/build.py +++ /dev/null @@ -1,37 +0,0 @@ -import sysconfig -import os -import shutil -import subprocess - - -def _build(name, src, srcdir, library_dirs, include_dirs, libraries): - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) - # try to avoid setuptools if possible - cc = os.environ.get("CC") - if cc is None: - # TODO: support more things here. 
- clang = shutil.which("clang") - gcc = shutil.which("gcc") - cc = gcc if gcc is not None else clang - if cc is None: - raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") - # This function was renamed and made public in Python 3.10 - if hasattr(sysconfig, 'get_default_scheme'): - scheme = sysconfig.get_default_scheme() - else: - scheme = sysconfig._get_default_scheme() - # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install - # path changes to include 'local'. This change is required to use triton with system-wide python. - if scheme == 'posix_local': - scheme = 'posix_prefix' - py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] - custom_backend_dirs = set(os.getenv(var) for var in ('TRITON_CUDACRT_PATH', 'TRITON_CUDART_PATH')) - include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] - # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 - cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] - cc_cmd += [f'-l{lib}' for lib in libraries] - cc_cmd += [f"-L{dir}" for dir in library_dirs] - cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] - subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) - return so diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 000000000000..2f6b33a6ad99 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,76 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:non_prod"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "@llvm-project//llvm:llc", +# "@llvm-project//llvm:opt", +# "@llvm-project//mlir:mlir-translate", +# "//:triton-llvm-opt", +# "//:triton-opt", +# "//:triton-tensor-layout", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# # broken, offending change reverted in +# # https://github.com/triton-lang/triton/commit/3ed479f2f91a1d94dacb547115d357f5ce3219d8 +# "Conversion/reduce_to_llvm.mlir", +# "Conversion/amd/dedup-by-constancy.mlir", # AMD-specific, broken +# "TritonGPU/amd/amd-instruction-sched.mlir", # AMD-specific, broken with -debug-only. 
+# "TritonGPU/optimize_epilogue.mlir", # TODO: b/346283526 - AMD-specific, triggering UBSAN +# # Broken between https://github.com/triton-lang/triton/commit/0dc2154e34ad0eb8d60ff2534755954aa8c8f20e +# # and https://github.com/triton-lang/triton/commit/196a08f04b92fcf0e52015d3b1068c18e4eea5b5 +# "TritonGPU/loop-pipeline.mlir", +# # Currently disabled because of cherry-pick in RemoveLayoutConversions.cpp +# "TritonGPU/remove-layout-combine.mlir", +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +exports_files(srcs = [ + "lib/Dialect/TritonGPU/TestTC05MMAPipeline.cpp", +]) diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index 7a86cecc8dd5..805430e9a81c 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -82,17 +82,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar tt.func public @async_wait(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) { - // The waitcnt stores all counters in one i32 bits 15:14 and 3:0 store the vmcnt we have to wait on - // CHECK: rocdl.waitcnt -49168 + // The swaitcnt stores all counters in one i32 bits 15:14 and 3:0 store the vmcnt we have to wait on + // CHECK: rocdl.s.waitcnt -49168 // CHECK: rocdl.barrier ttg.async_wait {num = 0 : i32} - // CHECK: rocdl.waitcnt -49167 + // CHECK: rocdl.s.waitcnt -49167 // CHECK: rocdl.barrier ttg.async_wait {num = 1 : i32} - // CHECK: rocdl.waitcnt -2 + // CHECK: rocdl.s.waitcnt -2 // CHECK: rocdl.barrier ttg.async_wait {num = 62 : i32} - // CHECK: rocdl.waitcnt -1 + // CHECK: rocdl.s.waitcnt -1 // CHECK: rocdl.barrier ttg.async_wait {num = 63 : i32} tt.return diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 1627487b1af7..51e29adcc2e2 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -431,3 +431,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return %d : tensor<128x128xf32, #blocked> } } + +// ----- + +// CHECK-DAG: #[[$BLOCKED:.*]] = #ttg.blocked +// CHECK-DAG: #mma = #ttg.nvidia_mma<{versionMajor = 3 +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func @local_alloc_dot_operand(%in0: tensor<64x256xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) { + // CHECK-LABEL: local_alloc_dot_operand + 
%splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = ttg.local_alloc + // CHECK: %[[RHS_CVT:.*]] = ttg.convert_layout {{.*}} #ttg.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]] + // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = ttg.local_alloc %[[RHS_CVT]] + // CHECK: ttng.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]] + %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked> + tt.return %res : tensor<64x32xf32, #blocked> + } +} diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir index e47b109863df..95164a42152b 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir @@ -1,5 +1,5 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=0' | FileCheck %s --check-prefixes MFMA0,CHECK -// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=16' | FileCheck %s --check-prefixes MFMA16,CHECK +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=0' | FileCheck %s --check-prefixes MFMA0,CHECK +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=16' | FileCheck %s --check-prefixes MFMA16,CHECK #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}> // CHECK-LABEL: mfma_dot_fp8e5m2 diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 75c3b4b205ca..15c77088318a 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops='arch-generation-name=gfx940'| FileCheck %s +// RUN: triton-opt %s -split-input-file --tritonamdgpu-convert-buffer-ops='arch-generation-name=gfx942'| FileCheck %s #blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 7af051dca5f8..1ace41640154 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -215,3 +215,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.return %b : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> } } + +// ----- + +// CHECK: #[[$BLOCKED:.*]] = #ttg.blocked +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>) -> !ttg.memdesc<256x32xf32, #shared1, #smem> { + // CHECK-LABEL: 
cvt_from_dot_op_into_local_allow_not_canonicalized + %cvt_in = ttg.convert_layout %in : tensor<256x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked> + %alloc = ttg.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared1, #smem> + // CHECK: %[[ALLOC:.*]] = ttg.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) -> + tt.return %alloc : !ttg.memdesc<256x32xf32, #shared1, #smem> + } +} // end module diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index a51900962f40..d369dd6b478b 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2380,12 +2380,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 %c4096_i32 = arith.constant 4096 : i32 - // CHECK: %[[F:.+]]:4 = scf.for + // CHECK: %[[F:.+]]:3 = scf.for // CHECK: %[[R:.+]] = arith.addf // CHECK: arith.addf - // CHECK: scf.yield %{{.+}}, %{{.+}}, %{{.+}}, %[[R]] + // CHECK: scf.yield %{{.+}}, %{{.+}}, %[[R]] // CHECK: } - // CHECK: tt.return %[[F]]#3, %[[F]]#1, %[[F]]#2 + // CHECK: tt.return %[[F]]#2, %[[F]]#1, %[[F]]#0 %1:3 = scf.for %arg0 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg1 = %cst, %arg3 = %cst_0, %arg4 = %cst) -> (tensor<32xf32, #blocked1>, tensor<32xf32, #blocked>, tensor<32xf32, #blocked1>) : i32 { %4 = arith.addf %arg1, %cst : tensor<32xf32, #blocked1> %5 = ttg.convert_layout %4 : tensor<32xf32, #blocked1> -> tensor<32xf32, #blocked> @@ -3339,6 +3339,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { // CHECK: tt.func @propagate_dot_op_to_constant_above_for() // CHECK: arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + // CHECK: tt.elementwise_inline_asm + // CHECK: scf.for + // CHECK: tt.dot tt.func @propagate_dot_op_to_constant_above_for() -> tensor<32x128xf32, #mma> { %cst = arith.constant dense<1.000000e+00> : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> @@ -3346,8 +3349,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 %c32_i32 = arith.constant 32 : i32 %c128_i32 = arith.constant 128 : i32 + %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %loop:1 = scf.for %arg2 = %c0_i32 to %c128_i32 step %c32_i32 iter_args(%arg0 = %cst_1) -> (tensor<32x128xf32, #mma>) : i32 { - %0 = tt.elementwise_inline_asm "cvt.rna.tf32.f32 $0, $1;" {constraints = "=r,r", packed_element = 1 : i32, pure = true} %cst : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %1 = ttg.convert_layout %0 : tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %2 = ttg.convert_layout %cst_0 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %3 = tt.dot %2, %1, %arg0, 
inputPrecision = tf32 : tensor<32x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x128xf32, #mma> diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 1274ca9154a0..f741a3187d15 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -244,3 +244,23 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt tt.return %loop#4 : tensor<128x128xf32, #C> } } // end module + + // ----- + +// CHECK: tt.func @matmul_loop_on_blocked_layout +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, %arg_rhs: !ttg.memdesc<512x32xf32, #shared, #smem, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) { + %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, !ttg.memdesc<512x32xf32, #shared, #smem, mutable>) : i32 { + %lhs_ll = ttg.local_load %lhs : !ttg.memdesc<16x512xf32, #shared, #smem, mutable> -> tensor<16x512xf32, #blocked> + %lhs_ll_cvt = ttg.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %rhs_ll = ttg.local_load %rhs : !ttg.memdesc<512x32xf32, #shared, #smem, mutable> -> tensor<512x32xf32, #blocked> + %rhs_ll_cvt = ttg.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : tensor<16x512xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked> + scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !ttg.memdesc<16x512xf32, #shared, #smem, mutable>, !ttg.memdesc<512x32xf32, #shared, #smem, mutable> + } + tt.return %loop#0 : tensor<16x32xf32, #blocked> + } +} // end module diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir b/test/TritonGPU/samples/simulated-grouped-gemm.mlir index 618995d2339e..a17d4aa5e60f 100644 --- a/test/TritonGPU/samples/simulated-grouped-gemm.mlir +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir @@ -153,115 +153,115 @@ // CHECK: %[[VAL_115:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_113]]#1 : !tt.tensordesc> to !tt.ptr // CHECK: ttng.async_tma_copy_global_to_local %[[VAL_115]]{{\[}}%[[VAL_113]]#6, %[[VAL_109]]] %[[VAL_114]], %[[VAL_110]], %[[VAL_75]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> // CHECK: %[[VAL_116:.*]] = ttg.local_alloc : () -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -// CHECK: %[[VAL_117:.*]]:24 = scf.for %[[VAL_118:.*]] = %[[VAL_13]] to %[[VAL_44]] step %[[VAL_10]] iter_args(%[[VAL_119:.*]] = %[[VAL_77]], %[[VAL_120:.*]] = %[[VAL_113]]#0, %[[VAL_121:.*]] = %[[VAL_113]]#1, %[[VAL_122:.*]] = %[[VAL_113]]#2, %[[VAL_123:.*]] = %[[VAL_113]]#3, %[[VAL_124:.*]] = %[[VAL_113]]#4, 
%[[VAL_125:.*]] = %[[VAL_113]]#5, %[[VAL_126:.*]] = %[[VAL_113]]#6, %[[VAL_127:.*]] = %[[VAL_22]], %[[VAL_128:.*]] = %[[VAL_9]], %[[VAL_129:.*]] = %[[VAL_10]], %[[VAL_130:.*]] = %[[VAL_12]], %[[VAL_131:.*]] = %[[VAL_13]], %[[VAL_132:.*]] = %[[VAL_113]]#7, %[[VAL_133:.*]] = %[[VAL_113]]#8, %[[VAL_134:.*]] = %[[VAL_113]]#9, %[[VAL_135:.*]] = %[[VAL_13]], %[[VAL_136:.*]] = %[[VAL_77]], %[[VAL_137:.*]] = %[[VAL_35]], %[[VAL_138:.*]] = %[[VAL_113]]#2, %[[VAL_139:.*]] = %[[VAL_72]]#0, %[[VAL_140:.*]] = %[[VAL_113]]#5, %[[VAL_141:.*]] = %[[VAL_72]]#1, %[[VAL_142:.*]] = %[[VAL_113]]#6) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) : i32 { -// CHECK: %[[VAL_143:.*]] = arith.subi %[[VAL_44]], %[[VAL_7]] : i32 -// CHECK: %[[VAL_144:.*]] = arith.cmpi slt, %[[VAL_118]], %[[VAL_143]] : i32 -// CHECK: %[[VAL_145:.*]] = arith.cmpi eq, %[[VAL_119]], %[[VAL_45]] : i32 -// CHECK: %[[VAL_146:.*]] = arith.addi %[[VAL_119]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_147:.*]] = arith.select %[[VAL_145]], %[[VAL_13]], %[[VAL_146]] : i32 -// CHECK: %[[VAL_148:.*]] = arith.cmpi eq, %[[VAL_147]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_149:.*]] = arith.andi %[[VAL_144]], %[[VAL_148]] : i1 -// CHECK: %[[VAL_150:.*]]:10 = scf.if %[[VAL_149]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { -// CHECK: %[[VAL_151:.*]] = arith.addi %[[VAL_124]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_152:.*]] = arith.cmpi eq, %[[VAL_151]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_153:.*]] = arith.select %[[VAL_152]], %[[VAL_13]], %[[VAL_151]] : i32 -// CHECK: %[[VAL_154:.*]]:6 = scf.if %[[VAL_152]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32) { -// CHECK: %[[VAL_155:.*]] = tt.addptr %[[VAL_0]], %[[VAL_43]] : !tt.ptr, i32 -// CHECK: %[[VAL_156:.*]] = arith.muli %[[VAL_132]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_157:.*]] = tt.addptr %[[VAL_46]], %[[VAL_156]] : !tt.ptr, i32 -// CHECK: %[[VAL_158:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 -// CHECK: tt.experimental_tensormap_create %[[VAL_157]], %[[VAL_155]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_158]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () -// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_157]] : !tt.ptr -// CHECK: %[[VAL_159:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_157]] : !tt.ptr to !tt.tensordesc> -// CHECK: %[[VAL_160:.*]] = arith.addi %[[VAL_132]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_161:.*]] = arith.cmpi slt, %[[VAL_160]], %[[VAL_8]] : i32 -// CHECK: %[[VAL_162:.*]] = arith.select %[[VAL_161]], %[[VAL_160]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_163:.*]] = tt.addptr %[[VAL_1]], %[[VAL_43]] : !tt.ptr, i32 -// CHECK: %[[VAL_164:.*]] = arith.muli %[[VAL_133]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_165:.*]] = tt.addptr %[[VAL_47]], %[[VAL_164]] : !tt.ptr, i32 -// CHECK: %[[VAL_166:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 -// CHECK: tt.experimental_tensormap_create %[[VAL_165]], %[[VAL_163]], {{\[}}%[[VAL_17]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_166]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () 
-// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_165]] : !tt.ptr -// CHECK: %[[VAL_167:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_165]] : !tt.ptr to !tt.tensordesc> -// CHECK: %[[VAL_168:.*]] = arith.addi %[[VAL_133]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_169:.*]] = arith.cmpi slt, %[[VAL_168]], %[[VAL_8]] : i32 -// CHECK: %[[VAL_170:.*]] = arith.select %[[VAL_169]], %[[VAL_168]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_171:.*]] = tt.addptr %[[VAL_2]], %[[VAL_43]] : !tt.ptr, i32 -// CHECK: %[[VAL_172:.*]] = arith.muli %[[VAL_134]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_173:.*]] = tt.addptr %[[VAL_48]], %[[VAL_172]] : !tt.ptr, i32 -// CHECK: %[[VAL_174:.*]] = arith.muli %[[VAL_34]], %[[VAL_6]] : i64 -// CHECK: tt.experimental_tensormap_create %[[VAL_173]], %[[VAL_171]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_174]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () -// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_173]] : !tt.ptr -// CHECK: %[[VAL_175:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_173]] : !tt.ptr to !tt.tensordesc> -// CHECK: %[[VAL_176:.*]] = arith.addi %[[VAL_134]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_177:.*]] = arith.cmpi slt, %[[VAL_176]], %[[VAL_8]] : i32 -// CHECK: %[[VAL_178:.*]] = arith.select %[[VAL_177]], %[[VAL_176]], %[[VAL_13]] : i32 -// CHECK: scf.yield %[[VAL_159]], %[[VAL_167]], %[[VAL_175]], %[[VAL_162]], %[[VAL_170]], %[[VAL_178]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: %[[VAL_117:.*]]:20 = scf.for %[[VAL_118:.*]] = %[[VAL_13]] to %[[VAL_44]] step %[[VAL_10]] iter_args(%[[VAL_119:.*]] = %[[VAL_77]], %[[VAL_120:.*]] = %[[VAL_113]]#0, %[[VAL_121:.*]] = %[[VAL_113]]#1, %[[VAL_122:.*]] = %[[VAL_113]]#2, %[[VAL_123:.*]] = %[[VAL_113]]#3, %[[VAL_124:.*]] = %[[VAL_113]]#4, %[[VAL_125:.*]] = %[[VAL_113]]#5, %[[VAL_126:.*]] = %[[VAL_113]]#6, %[[VAL_127:.*]] = %[[VAL_22]], %[[VAL_128:.*]] = %[[VAL_9]], %[[VAL_129:.*]] = %[[VAL_10]], %[[VAL_130:.*]] = %[[VAL_12]], %[[VAL_131:.*]] = %[[VAL_13]], %[[VAL_132:.*]] = %[[VAL_113]]#7, %[[VAL_133:.*]] = %[[VAL_113]]#8, %[[VAL_134:.*]] = %[[VAL_113]]#9, %[[VAL_135:.*]] = %[[VAL_13]], %[[VAL_136:.*]] = %[[VAL_35]], %[[VAL_137:.*]] = %[[VAL_72]]#0, %[[VAL_138:.*]] = %[[VAL_72]]#1) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, i32, i32) : i32 { +// CHECK: %[[VAL_139:.*]] = arith.subi %[[VAL_44]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_140:.*]] = arith.cmpi slt, %[[VAL_118]], %[[VAL_139]] : i32 +// CHECK: %[[VAL_141:.*]] = arith.cmpi eq, %[[VAL_119]], %[[VAL_45]] : i32 +// CHECK: %[[VAL_142:.*]] = arith.addi %[[VAL_119]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_143:.*]] = arith.select %[[VAL_141]], %[[VAL_13]], %[[VAL_142]] : i32 +// CHECK: %[[VAL_144:.*]] = arith.cmpi eq, %[[VAL_143]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_145:.*]] = arith.andi %[[VAL_140]], %[[VAL_144]] : i1 +// CHECK: %[[VAL_146:.*]]:10 = scf.if %[[VAL_145]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_147:.*]] = arith.addi %[[VAL_124]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_148:.*]] = arith.cmpi eq, %[[VAL_147]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_149:.*]] = arith.select %[[VAL_148]], %[[VAL_13]], %[[VAL_147]] : i32 +// CHECK: %[[VAL_150:.*]]:6 = 
scf.if %[[VAL_148]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32) { +// CHECK: %[[VAL_151:.*]] = tt.addptr %[[VAL_0]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_152:.*]] = arith.muli %[[VAL_132]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_153:.*]] = tt.addptr %[[VAL_46]], %[[VAL_152]] : !tt.ptr, i32 +// CHECK: %[[VAL_154:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_153]], %[[VAL_151]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_154]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_153]] : !tt.ptr +// CHECK: %[[VAL_155:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_153]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_156:.*]] = arith.addi %[[VAL_132]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_157:.*]] = arith.cmpi slt, %[[VAL_156]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_158:.*]] = arith.select %[[VAL_157]], %[[VAL_156]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_159:.*]] = tt.addptr %[[VAL_1]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_160:.*]] = arith.muli %[[VAL_133]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_161:.*]] = tt.addptr %[[VAL_47]], %[[VAL_160]] : !tt.ptr, i32 +// CHECK: %[[VAL_162:.*]] = arith.muli %[[VAL_31]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_161]], %[[VAL_159]], {{\[}}%[[VAL_17]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_162]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_161]] : !tt.ptr +// CHECK: %[[VAL_163:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_161]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_164:.*]] = arith.addi %[[VAL_133]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_165:.*]] = arith.cmpi slt, %[[VAL_164]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_166:.*]] = arith.select %[[VAL_165]], %[[VAL_164]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_167:.*]] = tt.addptr %[[VAL_2]], %[[VAL_43]] : !tt.ptr, i32 +// CHECK: %[[VAL_168:.*]] = arith.muli %[[VAL_134]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_169:.*]] = tt.addptr %[[VAL_48]], %[[VAL_168]] : !tt.ptr, i32 +// CHECK: %[[VAL_170:.*]] = arith.muli %[[VAL_34]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_169]], %[[VAL_167]], {{\[}}%[[VAL_17]], %[[VAL_15]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_170]]], {{\[}}%[[VAL_10]], %[[VAL_10]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_169]] : !tt.ptr +// CHECK: %[[VAL_171:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_169]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_172:.*]] = arith.addi %[[VAL_134]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_173:.*]] = arith.cmpi slt, %[[VAL_172]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_174:.*]] = arith.select %[[VAL_173]], %[[VAL_172]], %[[VAL_13]] : i32 +// CHECK: scf.yield %[[VAL_155]], %[[VAL_163]], %[[VAL_171]], %[[VAL_158]], %[[VAL_166]], %[[VAL_174]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 // CHECK: } else { // CHECK: scf.yield %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], 
%[[VAL_132]], %[[VAL_133]], %[[VAL_134]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 // CHECK: } -// CHECK: %[[VAL_179:.*]] = arith.addi %[[VAL_123]], %[[VAL_11]] : i32 -// CHECK: %[[VAL_180:.*]] = arith.divsi %[[VAL_179]], %[[VAL_42]] : i32 -// CHECK: %[[VAL_181:.*]] = arith.muli %[[VAL_180]], %[[VAL_14]] : i32 -// CHECK: %[[VAL_182:.*]] = arith.subi %[[VAL_25]], %[[VAL_181]] : i32 -// CHECK: %[[VAL_183:.*]] = arith.minsi %[[VAL_182]], %[[VAL_14]] : i32 -// CHECK: %[[VAL_184:.*]] = arith.remsi %[[VAL_179]], %[[VAL_183]] : i32 -// CHECK: %[[VAL_185:.*]] = arith.addi %[[VAL_181]], %[[VAL_184]] : i32 -// CHECK: %[[VAL_186:.*]] = arith.remsi %[[VAL_179]], %[[VAL_42]] : i32 -// CHECK: %[[VAL_187:.*]] = arith.divsi %[[VAL_186]], %[[VAL_183]] : i32 -// CHECK: %[[VAL_188:.*]] = arith.muli %[[VAL_185]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_189:.*]] = arith.muli %[[VAL_187]], %[[VAL_16]] : i32 -// CHECK: scf.yield %[[VAL_190:.*]]#0, %[[VAL_190]]#1, %[[VAL_190]]#2, %[[VAL_179]], %[[VAL_153]], %[[VAL_188]], %[[VAL_189]], %[[VAL_190]]#3, %[[VAL_190]]#4, %[[VAL_190]]#5 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: %[[VAL_175:.*]] = arith.addi %[[VAL_123]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_176:.*]] = arith.divsi %[[VAL_175]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_177:.*]] = arith.muli %[[VAL_176]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_178:.*]] = arith.subi %[[VAL_25]], %[[VAL_177]] : i32 +// CHECK: %[[VAL_179:.*]] = arith.minsi %[[VAL_178]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_180:.*]] = arith.remsi %[[VAL_175]], %[[VAL_179]] : i32 +// CHECK: %[[VAL_181:.*]] = arith.addi %[[VAL_177]], %[[VAL_180]] : i32 +// CHECK: %[[VAL_182:.*]] = arith.remsi %[[VAL_175]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_183:.*]] = arith.divsi %[[VAL_182]], %[[VAL_179]] : i32 +// CHECK: %[[VAL_184:.*]] = arith.muli %[[VAL_181]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_185:.*]] = arith.muli %[[VAL_183]], %[[VAL_16]] : i32 +// CHECK: scf.yield %[[VAL_186:.*]]#0, %[[VAL_186]]#1, %[[VAL_186]]#2, %[[VAL_175]], %[[VAL_149]], %[[VAL_184]], %[[VAL_185]], %[[VAL_186]]#3, %[[VAL_186]]#4, %[[VAL_186]]#5 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 // CHECK: } else { // CHECK: scf.yield %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], %[[VAL_123]], %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_132]], %[[VAL_133]], %[[VAL_134]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 // CHECK: } -// CHECK: %[[VAL_191:.*]] = arith.addi %[[VAL_130]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_192:.*]] = arith.cmpi slt, %[[VAL_191]], %[[VAL_8]] : i32 -// CHECK: %[[VAL_193:.*]] = arith.select %[[VAL_192]], %[[VAL_191]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_194:.*]] = arith.xori %[[VAL_131]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_195:.*]] = arith.select %[[VAL_192]], %[[VAL_131]], %[[VAL_194]] : i32 -// CHECK: %[[VAL_196:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_193]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: ttng.wait_barrier %[[VAL_196]], %[[VAL_195]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: %[[VAL_197:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_193]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -// CHECK: %[[VAL_198:.*]] = ttg.memdesc_subview 
%[[VAL_49]]{{\[}}%[[VAL_193]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> -// CHECK: %[[VAL_199:.*]] = ttg.memdesc_trans %[[VAL_197]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -// CHECK: %[[VAL_200:.*]] = ttng.warp_group_dot %[[VAL_198]], %[[VAL_199]], %[[VAL_127]], %[[VAL_128]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -> tensor<128x256xf32, #[[$ATTR_0]]> -// CHECK: %[[VAL_201:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_200]], %[[VAL_198]], %[[VAL_199]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -// CHECK: %[[VAL_202:.*]] = arith.addi %[[VAL_129]], %[[VAL_10]] : i32 -// CHECK: %[[VAL_203:.*]] = arith.cmpi slt, %[[VAL_202]], %[[VAL_8]] : i32 -// CHECK: %[[VAL_204:.*]] = arith.select %[[VAL_203]], %[[VAL_202]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_205:.*]] = arith.muli %[[VAL_147]], %[[VAL_17]] : i32 -// CHECK: %[[VAL_206:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_204]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: ttng.barrier_expect %[[VAL_206]], 49152, %[[VAL_144]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: %[[VAL_207:.*]] = ttg.memdesc_subview %[[VAL_49]]{{\[}}%[[VAL_204]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> -// CHECK: %[[VAL_208:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_209:.*]]#0 : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_208]]{{\[}}%[[VAL_209]]#5, %[[VAL_205]]] %[[VAL_207]], %[[VAL_206]], %[[VAL_144]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> -// CHECK: %[[VAL_210:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_204]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -// CHECK: %[[VAL_211:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_209]]#1 : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_211]]{{\[}}%[[VAL_209]]#6, %[[VAL_205]]] %[[VAL_210]], %[[VAL_206]], %[[VAL_144]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -// CHECK: %[[VAL_212:.*]] = arith.cmpi eq, %[[VAL_135]], %[[VAL_45]] : i32 -// CHECK: %[[VAL_213:.*]] = arith.cmpi ne, %[[VAL_135]], %[[VAL_45]] : i32 -// CHECK: scf.if %[[VAL_212]] { -// CHECK: %[[VAL_214:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_201]]#0, %[[VAL_198]], %[[VAL_199]] {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -// CHECK: %[[VAL_215:.*]] = arith.truncf %[[VAL_214]]#0 : tensor<128x256xf32, #[[$ATTR_0]]> to tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: 
%[[VAL_187:.*]] = arith.addi %[[VAL_130]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_188:.*]] = arith.cmpi slt, %[[VAL_187]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_189:.*]] = arith.select %[[VAL_188]], %[[VAL_187]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_190:.*]] = arith.xori %[[VAL_131]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_191:.*]] = arith.select %[[VAL_188]], %[[VAL_131]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_192:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_189]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.wait_barrier %[[VAL_192]], %[[VAL_191]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_193:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_189]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_194:.*]] = ttg.memdesc_subview %[[VAL_49]]{{\[}}%[[VAL_189]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_195:.*]] = ttg.memdesc_trans %[[VAL_193]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_196:.*]] = ttng.warp_group_dot %[[VAL_194]], %[[VAL_195]], %[[VAL_127]], %[[VAL_128]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> -> tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_197:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_196]], %[[VAL_194]], %[[VAL_195]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_198:.*]] = arith.addi %[[VAL_129]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_199:.*]] = arith.cmpi slt, %[[VAL_198]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_200:.*]] = arith.select %[[VAL_199]], %[[VAL_198]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_201:.*]] = arith.muli %[[VAL_143]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_202:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_200]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.barrier_expect %[[VAL_202]], 49152, %[[VAL_140]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_203:.*]] = ttg.memdesc_subview %[[VAL_49]]{{\[}}%[[VAL_200]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_204:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_205:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_204]]{{\[}}%[[VAL_205]]#5, %[[VAL_201]]] %[[VAL_203]], %[[VAL_202]], %[[VAL_140]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64> +// CHECK: %[[VAL_206:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_200]], %[[VAL_13]], %[[VAL_13]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: 
%[[VAL_207:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_205]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_207]]{{\[}}%[[VAL_205]]#6, %[[VAL_201]]] %[[VAL_206]], %[[VAL_202]], %[[VAL_140]] : !tt.ptr, !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x256x64> +// CHECK: %[[VAL_208:.*]] = arith.cmpi eq, %[[VAL_135]], %[[VAL_45]] : i32 +// CHECK: %[[VAL_209:.*]] = arith.cmpi ne, %[[VAL_135]], %[[VAL_45]] : i32 +// CHECK: scf.if %[[VAL_208]] { +// CHECK: %[[VAL_210:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_197]]#0, %[[VAL_194]], %[[VAL_195]] {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_211:.*]] = arith.truncf %[[VAL_210]]#0 : tensor<128x256xf32, #[[$ATTR_0]]> to tensor<128x256xf16, #[[$ATTR_0]]> // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} -// CHECK: ttg.local_store %[[VAL_215]], %[[VAL_116]] : tensor<128x256xf16, #[[$ATTR_0]]> -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: ttg.local_store %[[VAL_211]], %[[VAL_116]] : tensor<128x256xf16, #[[$ATTR_0]]> -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> // CHECK: ttng.fence_async_shared {bCluster = false} -// CHECK: %[[VAL_216:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_137]] : !tt.tensordesc> to !tt.ptr -// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_216]]{{\[}}%[[VAL_139]], %[[VAL_141]]] %[[VAL_116]] : !tt.ptr, !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> +// CHECK: %[[VAL_212:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_136]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_212]]{{\[}}%[[VAL_137]], %[[VAL_138]]] %[[VAL_116]] : !tt.ptr, !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> // CHECK: } -// CHECK: scf.yield %[[VAL_147]], %[[VAL_209]]#0, %[[VAL_209]]#1, %[[VAL_209]]#2, %[[VAL_209]]#3, %[[VAL_209]]#4, %[[VAL_209]]#5, %[[VAL_209]]#6, %[[VAL_201]]#0, %[[VAL_213]], %[[VAL_204]], %[[VAL_193]], %[[VAL_195]], %[[VAL_209]]#7, %[[VAL_209]]#8, %[[VAL_209]]#9, %[[VAL_136]], %[[VAL_147]], %[[VAL_138]], %[[VAL_209]]#2, %[[VAL_140]], %[[VAL_209]]#5, %[[VAL_142]], %[[VAL_209]]#6 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 +// CHECK: scf.yield %[[VAL_143]], %[[VAL_205]]#0, %[[VAL_205]]#1, %[[VAL_205]]#2, %[[VAL_205]]#3, %[[VAL_205]]#4, %[[VAL_205]]#5, %[[VAL_205]]#6, %[[VAL_197]]#0, %[[VAL_209]], %[[VAL_200]], %[[VAL_189]], %[[VAL_191]], %[[VAL_205]]#7, %[[VAL_205]]#8, %[[VAL_205]]#9, %[[VAL_119]], %[[VAL_122]], %[[VAL_125]], %[[VAL_126]] : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, i32, i32 // CHECK: } // CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} // CHECK: ttg.local_dealloc %[[VAL_116]] : !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> -// CHECK: %[[VAL_217:.*]] = ttng.warp_group_dot_wait %[[VAL_218:.*]]#8 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]> -// CHECK: %[[VAL_219:.*]] = ttg.async_wait {num = 0 : i32} -// CHECK: %[[VAL_220:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_13]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], 
#[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: ttng.inval_barrier %[[VAL_220]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: %[[VAL_221:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_10]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: ttng.inval_barrier %[[VAL_221]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: %[[VAL_222:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> -// CHECK: ttng.inval_barrier %[[VAL_222]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_213:.*]] = ttng.warp_group_dot_wait %[[VAL_214:.*]]#8 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_215:.*]] = ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_216:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_13]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_216]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_217:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_10]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_217]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: %[[VAL_218:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> +// CHECK: ttng.inval_barrier %[[VAL_218]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3> // CHECK: ttg.local_dealloc %[[VAL_49]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> // CHECK: ttg.local_dealloc %[[VAL_50]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable> // CHECK: tt.return diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 000000000000..2d4b2ff257d7 --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,268 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/backends/gpu/codegen/triton:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + "//:compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +cc_library( + name = "TritonAMDGPUTransforms", + srcs = glob( + [ + "lib/TritonAMDGPUTransforms/**/*.h", + "lib/TritonAMDGPUTransforms/**/*.cpp", + ], + exclude = [ + "lib/TritonAMDGPUTransforms/MfmaGroup.cpp", # Avoid circular dependency. + ], + ) + [ + # Work around dependencies on private headers. 
+ "lib/TritonAMDGPUToLLVM/SchedInstructions.h", + "lib/TritonAMDGPUToLLVM/TargetInfo.h", + "lib/TritonAMDGPUToLLVM/Utility.h", + ], + hdrs = glob([ + "include/TritonAMDGPUTransforms/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUTransforms", + ], + deps = [ + ":TritonAMDGPU", + ":TritonAMDGPUToLLVM", + ":triton_conversion_amdgpu_transforms_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + ], +) + +cc_library( + name = "TritonAMDGPU", + srcs = glob([ + "lib/Dialect/TritonAMDGPU/**/*.h", + "lib/Dialect/TritonAMDGPU/**/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/TritonAMDGPU/**/*.h", + ]), + includes = [ + "..", + "include", + ], + deps = [ + ":triton_amdgpu_attr_def_inc_gen", + ":triton_amdgpu_dialect_inc_gen", + ":triton_amdgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "lib/TritonAMDGPUToLLVM/**/*.h", + "lib/TritonAMDGPUToLLVM/**/*.cpp", + # TritonAMDGPUToLLVM and TritonAMDGPUDialectToLLVM have interdependencies, easiest way to + # deal with circular dependencies is to just compile both in a single unit. + "lib/TritonAMDGPUDialectToLLVM/**/*.h", + "lib/TritonAMDGPUDialectToLLVM/**/*.cpp", + ]) + [ + "include/TritonAMDGPUTransforms/MfmaGroup.h", # Avoid circular dependency. + "lib/TritonAMDGPUTransforms/MfmaGroup.cpp", # Avoid circular dependency. 
+ ], + hdrs = glob([ + "include/TritonAMDGPUToLLVM/**/*.h", + ]), + copts = _no_unused_variable + ["-Wno-implicit-fallthrough"], + includes = [ + "include", + "lib/TritonAMDGPUToLLVM", + ], + deps = [ + ":TritonAMDGPU", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:AMDGPUDialect", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "@triton//third_party/proton:TritonProtonToLLVM", + ], +) + +td_library( + name = "td_files", + srcs = glob(["include/**/*.td"]), + includes = ["include"], + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_ops_inc_gen", + tbl_outs = [ + ( + [ + "--gen-llvmir-conversions", + ], + "include/Dialect/TritonAMDGPU/IR/OpsConversions.inc", + ), + ( + [ + "--gen-op-decls", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.h.inc", + ), + ( + [ + "--gen-op-defs", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_dialect_inc_gen", + tbl_outs = [ + ( + [ + "--gen-dialect-decls", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.h.inc", + ), + ( + [ + "--gen-dialect-defs", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_attr_def_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUToLLVM/Passes.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_transforms_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPU", + ], + "include/TritonAMDGPUTransforms/Passes.h.inc", + ), + ], + tblgen = 
"@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUTransforms/Passes.td", + deps = [":td_files"], +) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 757802cb912a..491c1da65e54 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1682,7 +1682,7 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { unsigned otherCnts = ~0xC00F; // C00F has bits 15:14 and 3:0 set unsigned waitValue = lowBits | highBits | otherCnts; - rewriter.create(loc, waitValue); + rewriter.create(loc, waitValue); // Drop the result AsyncToken rewriter.replaceOp(op, b.i32_val(0)); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp index 88fda8164d37..10d0b308af1b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp @@ -13,8 +13,6 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { switch (kind) { case llvm::AMDGPU::GK_GFX950: case llvm::AMDGPU::GK_GFX942: - case llvm::AMDGPU::GK_GFX941: - case llvm::AMDGPU::GK_GFX940: return ISAFamily::CDNA3; case llvm::AMDGPU::GK_GFX90A: return ISAFamily::CDNA2; diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD new file mode 100644 index 000000000000..93829539e1b9 --- /dev/null +++ b/third_party/f2reduce/BUILD @@ -0,0 +1,31 @@ +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# copybara:uncomment_begin +# license( +# name = "license", +# license_text = "LICENCE.txt", +# ) +# +# licenses(["notice"]) +# +# exports_files(["LICENCE.txt"]) +# copybara:uncomment_end + +cc_library( + name = "f2reduce", + srcs = ["f2reduce.cpp"], + hdrs = ["f2reduce.h"], + # copybara:uncomment strip_include_prefix = "/third_party/triton", +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 000000000000..6d388b005d2d --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,319 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/backends/gpu:__subpackages__", + # "//third_party/tensorflow/compiler/xla/pjrt:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "cublas_headers", + hdrs = glob([ + "include/*.h", + ]), + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + compatible_with = [], + # copybara:uncomment_begin + # visibility = [ + # "@triton//python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUDialect", + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + ":cublas_headers", + 
"@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "@triton//python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-return-type", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + ":NVGPUDialect", + ":TritonNVIDIAGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob( + include = [ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ], + exclude = ["lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp"], + ), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h", + "lib/TritonNVIDIAGPUToLLVM/TargetInfo.h", + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + "lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM", + ], + deps = [ + ":NVGPUDialect", + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToLLVM", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "//:triton_gpu_attr_inc_gen", + "@triton//third_party/proton:TritonProtonToLLVM", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # 
copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +td_library( + name = "td_files", + srcs = glob(["include/Dialect/NVGPU/IR/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +cc_library( + name = "NVGPUDialect", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/NVGPU/IR/*.h", + ]), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = [ + "..", # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc + "../..", # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + "include", + ], + deps = [ + ":nvgpu_attr_inc_gen", + ":nvgpu_dialect_inc_gen", + ":nvgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + "//:TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + 
"@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 000000000000..a5b34aa5c29b --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,30 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +pybind_extension( + name = "cuda_utils", + srcs = ["cuda_utils.cc"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "//platforms/gpus/cuda/dynamic_libcuda", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc new file mode 100644 index 000000000000..3a63d299af08 --- /dev/null +++ b/third_party/nvidia/backend/cuda_utils.cc @@ -0,0 +1,929 @@ +#define PY_SSIZE_T_CLEAN +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cuda.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace { + +struct UniquePyObjectDeleter { + void operator()(PyObject* obj) { Py_DECREF(obj); } +}; +// A unique_ptr for PyObjects that automatically calls Py_DECREF once it goes +// out of scope. +using UniquePyObjectPtr = std::unique_ptr; + +// Raise a python exception if the CUDA result code is not CUDA_SUCCESS. +// Can be called even on threads that do not hold Python's Global Interpreter +// Lock (GIL), as the function will acquire one if needed. +inline bool gpuAssert(CUresult code, const char* file, int line) { + if (code == CUDA_SUCCESS) + return true; + const char* error = nullptr; + cuGetErrorString(code, &error); + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyErr_Format(PyExc_RuntimeError, "Triton Error [CUDA]: %s", error); + PyGILState_Release(gil_state); + return false; +} + +// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +#define CUDA_CHECK_AND_RETURN_NULL(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) \ + return NULL; \ + } while (0) + +// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. +#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ + do { \ + if (!gpuAssert((ans), __FILE__, __LINE__)) { \ + PyEval_RestoreThread(_save); \ + return NULL; \ + } \ + } while (0) + +// Used to check if functions exist in old CUDA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + +using cuLaunchKernelEx_t = CUresult (*)(const CUlaunchConfig* config, + CUfunction f, void** kernelParams, + void** extra); + +// Dynamically load the handle to cuLaunchKernelEx. 
+cuLaunchKernelEx_t getLaunchKernelExHandle() { + // Open the shared library + void* handle = dlopen("libcuda.so.1", RTLD_LAZY); + if (!handle) { + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); + return nullptr; + } + // Clear any existing error + dlerror(); + auto cuLaunchKernelExHandle = + reinterpret_cast(dlsym(handle, "cuLaunchKernelEx")); + // Check for errors + if (const char* dlsym_error = dlerror()) { + PyErr_Format(PyExc_RuntimeError, + "Failed to retrieve cuLaunchKernelEx from libcuda.so: %s", + dlsym_error); + return nullptr; + } + return cuLaunchKernelExHandle; +} + +// Configuration with all the information necessary to launch a compiled +// Triton kernel using the CUDA driver API. +struct TritonLaunchConfig { + // Represents CUDA's 3D ID structure of grids and clusters + struct Dim { + int x; + int y; + int z; + constexpr int size() const { return x * y * z; } + }; + Dim grid; // Number of clusters per grid + Dim cluster; // Number of blocks per cluster + int num_warps; // number of warps per block + int shared_memory; // Size of shared memory in bytes to allocate + CUstream stream; // CUDA Stream on which to launch the kernel + CUfunction function; // Pointer to the kernel to launch + void** params; // Parameters to pass to the kernel +}; + +// Launch a CUDA kernel with the given parameters. Raises a Python exception +// if the kernel launch fails. +PyObject* launchKernel(const TritonLaunchConfig& config) { + // Launching the kernel might take a while and does not use Python APIs, so + // we can release the Global Interpreter Lock so other threads can use Python + // APIs if needed. + Py_BEGIN_ALLOW_THREADS; + const auto& grid = config.grid; + const auto& cluster = config.cluster; + if (grid.size() == 0) { + PyEval_RestoreThread(_save); + Py_RETURN_NONE; + } + if (cluster.size() == 1) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuLaunchKernel( + config.function, grid.x, grid.y, grid.z, 32 * config.num_warps, 1, 1, + config.shared_memory, config.stream, config.params, 0)); + } else { + CUlaunchAttribute launchAttr[2]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = cluster.x; + launchAttr[0].value.clusterDim.y = cluster.y; + launchAttr[0].value.clusterDim.z = cluster.z; + launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; + launchAttr[1].value.clusterSchedulingPolicyPreference = + CU_CLUSTER_SCHEDULING_POLICY_SPREAD; + CUlaunchConfig cu_config; + cu_config.gridDimX = grid.x * cluster.x; + cu_config.gridDimY = grid.y * cluster.y; + cu_config.gridDimZ = grid.z * cluster.z; + cu_config.blockDimX = 32 * config.num_warps; + cu_config.blockDimY = 1; + cu_config.blockDimZ = 1; + cu_config.sharedMemBytes = config.shared_memory; + cu_config.hStream = config.stream; + cu_config.attrs = launchAttr; + cu_config.numAttrs = 2; + // cuLaunchKernelEx was added in CUDA 12, so load it dynamically to be + // able to link on CUDA 11 and earlier. + static cuLaunchKernelEx_t cuLaunchKernelExHandle = + getLaunchKernelExHandle(); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuLaunchKernelExHandle(&cu_config, config.function, config.params, 0)); + } + Py_END_ALLOW_THREADS; + Py_RETURN_NONE; +} + +// Interface used by various PyObject extractors to extract obj into a memory +// location pointed by ptr. Returns true if extraction succeeded, and false +// otherwise. 
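+// For example (illustrative), the extractor matched to an 'i32' argument
+// converts obj with a PyLong_* call and stores a 4-byte integer at ptr, while
+// the pointer extractor stores a CUdeviceptr.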
+using ExtractorType = bool (*)(PyObject* obj, void* ptr); + +// Enable peer access if dev_ptr is allocated on a different device than the +// device on which we will execute the kernel. +PyObject* enablePeerAccessIfNecessary(CUdeviceptr dev_ptr) { + CUmemorytype mem_type = CU_MEMORYTYPE_HOST; + CUresult status = cuPointerGetAttribute( + &mem_type, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, dev_ptr); + if (status != CUDA_SUCCESS || mem_type != CU_MEMORYTYPE_DEVICE) { + // Not peer memory + Py_RETURN_NONE; + } + int mem_device_ordinal = 0; + CUDA_CHECK_AND_RETURN_NULL(cuPointerGetAttribute( + &mem_device_ordinal, CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, dev_ptr)); + CUdevice mem_device = 0; + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGet(&mem_device, mem_device_ordinal)); + CUdevice compute_device = 0; + CUDA_CHECK_AND_RETURN_NULL(cuCtxGetDevice(&compute_device)); + if (mem_device != compute_device) { + CUcontext mem_ctx = nullptr; + CUDA_CHECK_AND_RETURN_NULL(cuDevicePrimaryCtxRetain(&mem_ctx, mem_device)); + CUresult status = cuCtxEnablePeerAccess(mem_ctx, /*flags=*/0); + if (status == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) { + status = CUDA_SUCCESS; + } + CUDA_CHECK_AND_RETURN_NULL(status); + } + Py_RETURN_NONE; +} + +// Extract a CUDA device pointer from a pointer-like PyObject obj, and store +// it to the memory location pointed by ptr. +bool extractPointer(PyObject* obj, void* ptr) { + auto dev_ptr = static_cast(ptr); + if (obj == Py_None) { + *dev_ptr = static_cast(0); // valid nullptr + return true; + } + if (PyLong_Check(obj)) { + *dev_ptr = PyLong_AsUnsignedLongLong(obj); + return true; + } + UniquePyObjectPtr ret(PyObject_CallMethod(obj, "data_ptr", nullptr)); + if (!ret.get()) { + PyErr_Format(PyExc_TypeError, + "Pointer argument must be either uint64 or have data_ptr " + "method, but got %R", + obj); + return false; + } + if (!PyLong_Check(ret.get())) { + PyErr_SetString(PyExc_TypeError, + "data_ptr method of Pointer object must return 64-bit int"); + return false; + } + *dev_ptr = PyLong_AsUnsignedLongLong(ret.get()); + if (PyErr_Occurred()) { + return false; + } + if (*dev_ptr == 0) { + return true; // valid nullptr + } + if (enablePeerAccessIfNecessary(*dev_ptr) == nullptr) { + return false; + } + CUresult status = cuPointerGetAttribute( + dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, *dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) { + PyErr_Format(PyExc_ValueError, + "Pointer argument cannot be accessed from Triton " + "(cpu tensor?)"); + return false; + } else if (status != CUDA_SUCCESS) { + CUDA_CHECK(status); + return false; + } + return true; +} + +// For a given type T, maps to the Python API with signature `U (*)(PyObject*)` +// that can extract values of that type from a PyObject. Note that the return +// type U is not guaranteed to be the same as T, but it can be explicitly casted +// to T. +template +constexpr auto kValueFunction = nullptr; +template +constexpr auto + kValueFunction && + std::is_signed_v && sizeof(T) <= 4>> = + PyLong_AsLong; +template <> +constexpr auto kValueFunction = PyLong_AsLongLong; +template +constexpr auto kValueFunction< + T, std::enable_if_t && std::is_unsigned_v && + sizeof(T) <= 4>> = PyLong_AsUnsignedLong; +template <> +constexpr auto kValueFunction = PyLong_AsUnsignedLongLong; +template +constexpr auto + kValueFunction>> = + PyFloat_AsDouble; + +// Extract a value of type T from obj and store it into memory pointed by ptr. +// Returns true if extraction succeeded, and false otherwise. 
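+// On failure an appropriate Python exception is left set, so callers can
+// simply return nullptr and let it propagate (see extractValue and
+// extractPointer).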
+template +bool extractValue(PyObject* obj, void* ptr) { + *static_cast(ptr) = static_cast(kValueFunction(obj)); + return PyErr_Occurred() == nullptr; +} + +// Contains information necessary for extracting a certain type from a PyObject. +struct ExtractionInfo { + // Prefixes of types reprs supported by the extractor. + llvm::SmallVector supported_type_repr_prefixes; + std::size_t size; // Size required by the extracted value. + ExtractorType extractor; // Function to call to extract the value. + + // Builds an ExtractionInfo for a given type T and a list of type reprs that + // are backed by that type. + template + static ExtractionInfo build( + std::initializer_list supported_type_reprs, + ExtractorType extractor = extractValue) { + return {supported_type_reprs, sizeof(T), extractor}; + } + + // Checks if the extractor supports extracting a given type repr. + bool supports(llvm::StringRef type_repr) const { + return llvm::any_of( + supported_type_repr_prefixes, + [&](llvm::StringRef prefix) { return type_repr.starts_with(prefix); }); + } +}; + +// Array of supported extractors +const ExtractionInfo kExtractionInfos[]{ + ExtractionInfo::build({"'i8'"}), + ExtractionInfo::build({"'i16'"}), + ExtractionInfo::build({"'i1'", "'i32'"}), + ExtractionInfo::build({"'i64'"}), + ExtractionInfo::build({"'u8'"}), + ExtractionInfo::build({"'u16'"}), + ExtractionInfo::build({"'u1'", "'u32'"}), + ExtractionInfo::build({"'u64'"}), + ExtractionInfo::build({"'fp16'", "'bf16'", "'fp32'", "'f32'"}), + ExtractionInfo::build({"'fp64'"}), + // Note: types are e.g. '*fp32', so no closing quote is intentional. + ExtractionInfo::build({"'*"}, extractPointer), + ExtractionInfo{ + {"None", "'none'"}, 0, nullptr}, // Represent constexprs as None +}; + +// Finds an extractor that supports a given type_repr in the extractor list. +// Returns nullopt if no such extractor is found. +std::optional findExtractor(llvm::StringRef type_repr) { + constexpr std::size_t kNumExtractors = std::size(kExtractionInfos); + static_assert(kNumExtractors < std::numeric_limits::max(), + "Not enough bits in a byte to store the extractor index"); + for (const auto& [idx, info] : llvm::enumerate(kExtractionInfos)) { + if (info.supports(type_repr)) return idx; + } + return std::nullopt; +} + +PyDoc_STRVAR(buildSignatureMetadata__doc__, + R"(buildSignatureMetadata(signature_iterator) -> bytes + +Build a metadata object describing the signature of a kernel. + +This can then be passed as the signature_metadata parameter to the launch() +function. 
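+
+For example (illustrative), a signature of ["i32", "*fp32", None] yields one
+metadata byte per parameter: an int32 extractor, a pointer extractor, and a
+constexpr placeholder that is skipped at launch time.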
+ +:param signature: list of types describing the signature of a kernel, + specialized parameters should be represented with None +:type signature: sequence or iterable +:return: an opaque metadata object which can then be passed to launch() +:rtype: bytes +)"); +PyObject* buildSignatureMetadata(PyObject* self, PyObject* args) { + PyObject* signature = nullptr; + if (!PyArg_ParseTuple(args, "O", &signature)) { + return nullptr; + } + if (!PyIter_Check(signature)) { + PyErr_Format(PyExc_TypeError, + "expected signature to be an iterable, got %R", signature); + return nullptr; + } + + llvm::SmallVector signature_metadata; + while (UniquePyObjectPtr obj_type{PyIter_Next(signature)}) { + UniquePyObjectPtr repr(PyObject_Repr(obj_type.get())); + if (!repr) { + return nullptr; + } + UniquePyObjectPtr repr_str( + PyUnicode_AsEncodedString(repr.get(), "utf-8", "~E~")); + if (!repr_str) { + return nullptr; + } + const char* repr_bytes = PyBytes_AsString(repr_str.get()); + if (!repr_bytes) { + return nullptr; + } + std::optional extractor_idx = findExtractor(repr_bytes); + if (!extractor_idx.has_value()) { + PyErr_Format(PyExc_TypeError, + "unexpected type %R in kernel signature, dir: %R", + obj_type.get(), PyObject_Dir(obj_type.get())); + return nullptr; + } + signature_metadata.push_back(extractor_idx.value()); + } + if (PyErr_Occurred()) { + return nullptr; + } + + return PyBytes_FromStringAndSize(signature_metadata.data(), + signature_metadata.size()); +} + +// Launch a Python callable hook with metadata passed as parameters. +bool launchHook(PyObject* hook, PyObject* metadata) { + if (hook == Py_None) { + return true; + } + UniquePyObjectPtr args(Py_BuildValue("(O)", metadata)); + if (!args) { + return false; + } + UniquePyObjectPtr ret(PyObject_CallObject(hook, args.get())); + return static_cast(ret); +} + +static void ensureCudaContext() { + CUcontext pctx; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + // Ensure device context. + CUdevice device; + CUDA_CHECK(cuDeviceGet(&device, 0)); + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } +} + +PyDoc_STRVAR( + launch__doc__, + R"(launch(gridDimX, gridDimY, gridDimZ, stream, kernel, packed_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, kernel_arg_types, global_scratch, kernel_args) + +Launch a kernel on an Nvidia GPU. 
+
+:param gridDimX: X dimension of the grid
+:type gridDimX: signed integer
+:param gridDimY: Y dimension of the grid
+:type gridDimY: signed integer
+:param gridDimZ: Z dimension of the grid
+:type gridDimZ: signed integer
+:param stream: CUDA Stream to launch on
+:type stream: unsigned long integer (pointer)
+:param kernel: CUDA kernel to launch
+:type kernel: unsigned long integer (pointer)
+:param packed_metadata: Kernel metadata, including in sequence:
+    number of warps, number of CTAs, required bytes of shared memory,
+    cluster dimensions x, y, and z
+:type packed_metadata: 6-tuple
+:param hook_args: arguments to pass to the enter and exit hooks
+:type hook_args: object
+:param launch_enter_hook: hook to call just before launching the kernel
+:type launch_enter_hook: callable
+:param launch_exit_hook: hook to call just after launching the kernel
+:type launch_exit_hook: callable
+:param signature_metadata: metadata built from build_signature_metadata
+:type signature_metadata: bytes
+:param global_scratch: pointer to global scratch memory
+:type global_scratch: pointer
+:param kernel_args: kernel parameters
+:type kernel_args: tuple
+
+:raises RuntimeError: on kernel launch failure
+)");
+PyObject* launch(PyObject* self, PyObject* args) {
+  ensureCudaContext();
+  TritonLaunchConfig config{};
+  auto& grid = config.grid;
+  auto& cluster = config.cluster;
+  // PyObject* kernel_metadata = nullptr;
+  PyObject* hook_args = nullptr;
+  PyObject* launch_enter_hook = nullptr;
+  PyObject* launch_exit_hook = nullptr;
+  PyBytesObject* signature_metadata_bytes = nullptr;
+  PyObject* kernel_args = nullptr;
+  PyObject* global_scratch = nullptr;
+  int num_ctas = 0;
+  if (!PyArg_ParseTuple(args, "iiiKK(iiiiii)OOOSOO", &grid.x, &grid.y, &grid.z,
+                        &config.stream, &config.function, &config.num_warps,
+                        &num_ctas, &config.shared_memory, &cluster.x,
+                        &cluster.y, &cluster.z, &hook_args, &launch_enter_hook,
+                        &launch_exit_hook, &signature_metadata_bytes,
+                        &global_scratch, &kernel_args)) {
+    return nullptr;
+  }
+  if (num_ctas != cluster.size()) {
+    PyErr_Format(
+        PyExc_ValueError,
+        "Expected cluster dimensions (%d, %d, %d) to have a total size of %d",
+        cluster.x, cluster.y, cluster.z, num_ctas);
+    return nullptr;
+  }
+  llvm::ArrayRef signature_metadata(
+      PyBytes_AS_STRING(signature_metadata_bytes),
+      PyBytes_GET_SIZE(signature_metadata_bytes));
+  UniquePyObjectPtr fast_kernel_args(PySequence_Fast(
+      kernel_args, "Expected kernel_args to be a sequence or iterable"));
+  if (!fast_kernel_args) {
+    return nullptr;
+  }
+  llvm::ArrayRef kernel_args_data(
+      PySequence_Fast_ITEMS(fast_kernel_args.get()),
+      PySequence_Fast_GET_SIZE(fast_kernel_args.get()));
+
+  if (signature_metadata.size() != kernel_args_data.size()) {
+    PyErr_Format(PyExc_TypeError,
+                 "Expected kernel to have %d parameters, but got %d",
+                 signature_metadata.size(), kernel_args_data.size());
+    return nullptr;
+  }
+
+  // +1 for the global scratch pointer.
+  std::size_t num_params = signature_metadata.size() + 1;
+  // Use alloca to set up kernel parameters on the stack and avoid dynamic
+  // memory allocations.
+  config.params = static_cast(alloca(num_params * sizeof(void*)));
+  // This loop has to stay in the same function that owns params, since we are
+  // using alloca to allocate pointers to it on the stack of the function.
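+  // For example (illustrative), a kernel taking (i32, *fp32, constexpr) ends
+  // up with three params entries: a 4-byte integer slot, a CUdeviceptr slot,
+  // and the trailing global-scratch CUdeviceptr; the constexpr is skipped.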
+ std::size_t params_idx = 0; + for (const auto& [converter_idx, arg] : + llvm::zip(signature_metadata, kernel_args_data)) { + if (converter_idx >= std::size(kExtractionInfos)) { + PyErr_SetString(PyExc_ValueError, "corrupted signature metadata"); + return nullptr; + } + const ExtractionInfo& extraction_info = kExtractionInfos[converter_idx]; + if (extraction_info.size == 0) { + continue; // skip adding constexpr parameters + } + config.params[params_idx] = alloca(extraction_info.size); + if (!extraction_info.extractor(arg, config.params[params_idx])) { + return nullptr; + } + ++params_idx; + } + config.params[params_idx] = alloca(sizeof(void*)); + if (!extractPointer(global_scratch, config.params[params_idx])) { + return nullptr; + } + + if (!launchHook(launch_enter_hook, hook_args)) { + return nullptr; + } + + if(!launchKernel(config)) { + return nullptr; + } + + if (!launchHook(launch_exit_hook, hook_args)) { + return nullptr; + } + + Py_RETURN_NONE; +} + +} // namespace + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + CUdevice device; + cuDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int max_num_regs; + int multiprocessor_count; + int warp_size; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + CUDA_CHECK_AND_RETURN_NULL( + cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( + &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "max_num_regs", max_num_regs, + "multiprocessor_count", multiprocessor_count, "warpSize", + warp_size, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + CUdevice device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + CUfunction fun; + CUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + // create driver handles + CUcontext pctx = 0; + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGet(&device, 0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); + } + + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + 
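+  // (CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES reports local memory usage in bytes;
+  // the division by 4 below converts it to a count of spilled 32-bit
+  // registers.)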
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + if (shared > 49152 && shared_optin > 49152) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( + &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( + &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + Py_END_ALLOW_THREADS; + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUlaunchConfig *config); + +#if CUDA_VERSION >= 12000 +typedef CUresult (*cuTensorMapEncodeTiled_t)( + CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, + const cuuint64_t *globalStrides, const cuuint32_t *boxDim, + const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill); +#endif + +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ + if (!libHandle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ + return NULL; \ + } \ + /* Clear any existing error */ \ + dlerror(); \ + symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from libcuda.so.1"); \ + dlclose(libHandle); \ + return NULL; \ + } \ + return funcHandle; \ + } + +defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); + +#if CUDA_VERSION >= 12000 +defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); +#endif + +static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { + int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, + maxActiveClusters = -1; + int shared = 0; + CUfunction func; + + if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, + &clusterDimY, &clusterDimZ)) { + return NULL; + } + + // Let each SM have one block + int maxActiveBlocks = 1; + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); + Py_END_ALLOW_THREADS; + + CUlaunchAttribute launchAttr[1]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + CUlaunchConfig config; + 
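+  // Describe a grid that holds exactly one cluster (maxActiveBlocks == 1);
+  // cuOccupancyMaxActiveClusters below then reports how many such clusters
+  // can be resident on the device at the same time.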
config.gridDimX = clusterDimX; + config.gridDimY = maxActiveBlocks * clusterDimY; + config.gridDimZ = clusterDimZ; + config.blockDimX = 128; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared; + config.hStream = 0; + config.numAttrs = 1; + config.attrs = launchAttr; + + static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, + getCuOccupancyMaxActiveClustersHandle); + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); + Py_END_ALLOW_THREADS; + return PyLong_FromLong(maxActiveClusters); +} + +static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { + long size; + if (!PyArg_ParseTuple(args, "l", &size)) { + return NULL; + } + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS; + + // Ensure we have an active context. + CUcontext ctx = NULL; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); + if (!ctx) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); + } + + // We can't set the fifo size after running a kernel that calls printf. This + // is true even if the set() call is a nop and the new size is the same as the + // old size. + // + // This is unfriendly, so check if the old size matches the new size, and skip + // the set() call if so. + size_t oldSize = 0; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); + if (oldSize != size) { + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); + } + + Py_END_ALLOW_THREADS; + Py_INCREF(Py_None); + return Py_None; +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. 
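+// Called from Python (illustrative) as
+//   fill_1d_tma_descriptor(global_address, dim, tensorDim, elementSize,
+//                          desc_address)
+// where desc_address is the address of a caller-allocated CUtensorMap that
+// this function fills in.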
+static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else + unsigned long long global_address; + uint64_t dim; + uint32_t tensorDim; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, + &elementSize, &desc_address)) { + return NULL; + } + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t boxDim[1] = {tensorDim}; + uint32_t elementStrides[1] = {1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + return NULL; + } + assert((elementSize * tensorDim) >= 32 && "block size too small."); + int rank = 1; + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +#endif +} + +// Simple helper to experiment creating TMA descriptors on the host. +// This is a useful to test TMA operations independently. +static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else + unsigned long long global_address; + uint64_t dims[2]; + uint32_t tensorDims[2]; + int elementSize; + unsigned long long desc_address; + if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], + &tensorDims[1], &tensorDims[0], &elementSize, + &desc_address)) { + return NULL; + } + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); + } + int rank = 2; + // Swizzling should be picked in codegen but since we need to set it on the + // descriptor we rely on a convention between this function and codegen. + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + assert(false && "block size too small."); + } + // The bounding box inner dimension must be less than or equal to the swizzle + // size. + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 + // We clamp the block size and the codegen will emit multiple copy operations. 
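+  // For example (illustrative): a 256-element fp16 row is 512 bytes wide, so
+  // the 128B swizzle is selected above and the box below is clamped to
+  // 128 / 2 = 64 elements; codegen then emits multiple copies to cover the
+  // full row.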
+ if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); + CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( + (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + Py_INCREF(Py_None); + return Py_None; +#endif +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided cubin into CUDA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, + "Python interface for cuOccupancyMaxActiveClusters function"}, + {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, + "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " + "controls how many bytes can be streamed from kernels before data starts " + "being dropped. This inherits all the limitations of this call; in " + "particular it's an error to change this value after launching any kernel " + "that calls printf()."}, + {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, + {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, + {"build_signature_metadata", buildSignatureMetadata, METH_VARARGS, + buildSignatureMetadata__doc__}, + {"launch", launch, METH_VARARGS, launch__doc__}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cuda_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + + PyModule_AddFunctions(m, ModuleMethods); + + return m; +} diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c deleted file mode 100644 index 12deb0d1e7a3..000000000000 --- a/third_party/nvidia/backend/driver.c +++ /dev/null @@ -1,421 +0,0 @@ -#include "cuda.h" -#include -#include -#define PY_SSIZE_T_CLEAN -#include - -// Raises a Python exception and returns false if code is not CUDA_SUCCESS. -static bool gpuAssert(CUresult code, const char *file, int line) { - if (code == CUDA_SUCCESS) - return true; - - const char *prefix = "Triton Error [CUDA]: "; - const char *str; - cuGetErrorString(code, &str); - char err[1024] = {0}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - return false; -} - -// To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block. -#define CUDA_CHECK_AND_RETURN_NULL(ans) \ - do { \ - if (!gpuAssert((ans), __FILE__, __LINE__)) \ - return NULL; \ - } while (0) - -// To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block. -#define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \ - do { \ - if (!gpuAssert((ans), __FILE__, __LINE__)) { \ - PyEval_RestoreThread(_save); \ - return NULL; \ - } \ - } while (0) - -// Used to check if functions exist in old CUDA driver versions. 
-#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ - do { \ - if ((funcPointer) == NULL) { \ - (funcPointer) = (initializerFunction)(); \ - if ((funcPointer) == NULL) { \ - return NULL; \ - } \ - } \ - } while (0) - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - // Get device handle - CUdevice device; - cuDeviceGet(&device, device_id); - - // create a struct to hold device properties - int max_shared_mem; - int max_num_regs; - int multiprocessor_count; - int warp_size; - int sm_clock_rate; - int mem_clock_rate; - int mem_bus_width; - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); - CUDA_CHECK_AND_RETURN_NULL( - cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); - CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute( - &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); - - return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", - max_shared_mem, "max_num_regs", max_num_regs, - "multiprocessor_count", multiprocessor_count, "warpSize", - warp_size, "sm_clock_rate", sm_clock_rate, - "mem_clock_rate", mem_clock_rate, "mem_bus_width", - mem_bus_width); -} - -static PyObject *loadBinary(PyObject *self, PyObject *args) { - const char *name; - const char *data; - Py_ssize_t data_size; - int shared; - int device; - if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, - &device)) { - return NULL; - } - CUfunction fun; - CUmodule mod; - int32_t n_regs = 0; - int32_t n_spills = 0; - // create driver handles - CUcontext pctx = 0; - - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx)); - if (!pctx) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx)); - } - - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuModuleGetFunction(&fun, mod, name)); - // get allocated registers and spilled registers from the function - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); - n_spills /= 4; - // set dynamic shared memory if necessary - int shared_optin; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( - &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - device)); - if (shared > 49152 && shared_optin > 49152) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); - int shared_total, shared_static; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute( - &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, - device)); - 
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute( - &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_optin - shared_static)); - } - Py_END_ALLOW_THREADS; - - if (PyErr_Occurred()) { - return NULL; - } - return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, - n_spills); -} - -typedef CUresult (*cuOccupancyMaxActiveClusters_t)( - int *numClusters, CUfunction func, const CUlaunchConfig *config); - -typedef CUresult (*cuTensorMapEncodeTiled_t)( - CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, - const cuuint64_t *globalStrides, const cuuint32_t *boxDim, - const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, - CUtensorMapFloatOOBfill oobFill); - -#define defineGetFunctionHandle(name, symbolName) \ - static symbolName##_t name() { \ - /* Open the shared library */ \ - void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \ - if (!libHandle) { \ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \ - return NULL; \ - } \ - /* Clear any existing error */ \ - dlerror(); \ - symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ - /* Check for errors */ \ - const char *err = dlerror(); \ - if (err) { \ - PyErr_SetString(PyExc_RuntimeError, \ - "Failed to retrieve " #symbolName " from libcuda.so.1"); \ - dlclose(libHandle); \ - return NULL; \ - } \ - return funcHandle; \ - } - -defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, - cuOccupancyMaxActiveClusters); - -defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, - cuTensorMapEncodeTiled); - -static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { - int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, - maxActiveClusters = -1; - int shared = 0; - CUfunction func; - - if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, - &clusterDimY, &clusterDimZ)) { - return NULL; - } - - // Let each SM have one block - int maxActiveBlocks = 1; - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( - func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); - Py_END_ALLOW_THREADS; - - CUlaunchAttribute launchAttr[1]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - CUlaunchConfig config; - config.gridDimX = clusterDimX; - config.gridDimY = maxActiveBlocks * clusterDimY; - config.gridDimZ = clusterDimZ; - config.blockDimX = 128; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared; - config.hStream = 0; - config.numAttrs = 1; - config.attrs = launchAttr; - - static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, - getCuOccupancyMaxActiveClustersHandle); - - Py_BEGIN_ALLOW_THREADS; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( - func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); - Py_END_ALLOW_THREADS; - return PyLong_FromLong(maxActiveClusters); -} - -static PyObject 
*setPrintfFifoSize(PyObject *self, PyObject *args) { - long size; - if (!PyArg_ParseTuple(args, "l", &size)) { - return NULL; - } - if (size < 0) { - PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative"); - return NULL; - } - - Py_BEGIN_ALLOW_THREADS; - - // Ensure we have an active context. - CUcontext ctx = NULL; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx)); - if (!ctx) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuDevicePrimaryCtxRetain(&ctx, /*device=*/0)); - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx)); - } - - // We can't set the fifo size after running a kernel that calls printf. This - // is true even if the set() call is a nop and the new size is the same as the - // old size. - // - // This is unfriendly, so check if the old size matches the new size, and skip - // the set() call if so. - size_t oldSize = 0; - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE)); - if (oldSize != size) { - CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( - cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size)); - } - - Py_END_ALLOW_THREADS; - Py_INCREF(Py_None); - return Py_None; -} - -// Simple helper to experiment creating TMA descriptors on the host. -// This is a useful to test TMA operations independently. -static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { - unsigned long long global_address; - uint64_t dim; - uint32_t tensorDim; - int elementSize; - unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKiiK", &global_address, &dim, &tensorDim, - &elementSize, &desc_address)) { - return NULL; - } - uint64_t dims[1] = {dim}; - uint64_t globalStrides[1] = {dim * elementSize}; - uint32_t boxDim[1] = {tensorDim}; - uint32_t elementStrides[1] = {1}; - CUtensorMapDataType type; - switch (elementSize) { - case 1: - type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 2: - type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 4: - type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - default: - PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - return NULL; - } - assert((elementSize * tensorDim) >= 32 && "block size too small."); - int rank = 1; - static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, - getCuTensorMapEncodeTiledHandle); - CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( - (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - Py_INCREF(Py_None); - return Py_None; -} - -// Simple helper to experiment creating TMA descriptors on the host. -// This is a useful to test TMA operations independently. 
-static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { - unsigned long long global_address; - uint64_t dims[2]; - uint32_t tensorDims[2]; - int elementSize; - unsigned long long desc_address; - if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0], - &tensorDims[1], &tensorDims[0], &elementSize, - &desc_address)) { - return NULL; - } - uint64_t globalStrides[2] = {dims[0] * elementSize, - dims[0] * dims[1] * elementSize}; - uint32_t elementStrides[2] = {1, 1}; - CUtensorMapDataType type; - switch (elementSize) { - case 1: - type = CU_TENSOR_MAP_DATA_TYPE_UINT8; - break; - case 2: - type = CU_TENSOR_MAP_DATA_TYPE_UINT16; - break; - case 4: - type = CU_TENSOR_MAP_DATA_TYPE_UINT32; - break; - default: - PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4"); - } - int rank = 2; - // Swizzling should be picked in codegen but since we need to set it on the - // descriptor we rely on a convention between this function and codegen. - CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; - if (contigDimSizeInByte >= 128) { - swizzle = CU_TENSOR_MAP_SWIZZLE_128B; - } else if (contigDimSizeInByte >= 64) { - swizzle = CU_TENSOR_MAP_SWIZZLE_64B; - } else if (contigDimSizeInByte >= 32) { - swizzle = CU_TENSOR_MAP_SWIZZLE_32B; - } else { - assert(false && "block size too small."); - } - // The bounding box inner dimension must be less than or equal to the swizzle - // size. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 - // We clamp the block size and the codegen will emit multiple copy operations. - if (contigDimSizeInByte > 128) { - tensorDims[0] = 128 / elementSize; - } - static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, - getCuTensorMapEncodeTiledHandle); - CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled( - (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, - globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - Py_INCREF(Py_None); - return Py_None; -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBinary, METH_VARARGS, - "Load provided cubin into CUDA driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, - "Python interface for cuOccupancyMaxActiveClusters function"}, - {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS, - "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which " - "controls how many bytes can be streamed from kernels before data starts " - "being dropped. 
This inherits all the limitations of this call; in " - "particular it's an error to change this value after launching any kernel " - "that calls printf()."}, - {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"}, - {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"}, - - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_cuda_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - - PyModule_AddFunctions(m, ModuleMethods); - - return m; -} diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index d088ec0927da..2e98e52a377b 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -1,20 +1,15 @@ +from collections.abc import Callable import functools import os -import sysconfig -import hashlib import subprocess -import tempfile -from pathlib import Path -from triton.runtime.build import _build -from triton.runtime.cache import get_cache_manager from triton.runtime import _allocation from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver +from ._C import cuda_utils dirname = os.path.dirname(os.path.realpath(__file__)) include_dir = [os.path.join(dirname, "include")] libdevice_dir = os.path.join(dirname, "lib") -libraries = ['cuda'] @functools.lru_cache() @@ -47,26 +42,6 @@ def library_dirs(): return [libdevice_dir, *libcuda_dirs()] -def compile_module_from_src(src, name): - key = hashlib.sha256(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] - cache_path = cache.get_file(f"{name}.{ext}") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.c") - with open(src_path, "w") as f: - f.write(src) - so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - - # ------------------------ # Utils # ------------------------ @@ -80,13 +55,12 @@ def __new__(cls): return cls.instance def __init__(self): - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters - self.set_printf_fifo_size = mod.set_printf_fifo_size - self.fill_1d_tma_descriptor = mod.fill_1d_tma_descriptor - self.fill_2d_tma_descriptor = mod.fill_2d_tma_descriptor + self.load_binary = cuda_utils.load_binary + self.get_device_properties = cuda_utils.get_device_properties + self.cuOccupancyMaxActiveClusters = cuda_utils.cuOccupancyMaxActiveClusters + self.set_printf_fifo_size = cuda_utils.set_printf_fifo_size + self.fill_1d_tma_descriptor = cuda_utils.fill_1d_tma_descriptor + self.fill_2d_tma_descriptor = cuda_utils.fill_2d_tma_descriptor # ------------------------ @@ -95,7 +69,7 @@ def __init__(self): def ty_to_cpp(ty): - if ty[0] == '*': + if ty[0] == '*' or ty == "none": return "CUdeviceptr" return { "i1": "int32_t", @@ -117,386 +91,82 @@ def ty_to_cpp(ty): }[ty] -def 
make_launcher(constants, signature): +def flatten_tuples(xs): + """Recursively flattens tuple elements in xs.""" + for x in xs: + if isinstance(x, tuple): + yield from flatten_tuples(x) + else: + yield x + + +def make_launcher(constants : dict[int, str], signature : dict[int, any]) -> Callable[..., None]: + # Here, signature can look like: + # {'_0': 'i32', + # 'Ptrs': (), + # '_1': 'constexpr', + # 'values': '[*f32, constexpr]', + # 'out_tuple': 'constexpr'} + # We want to remove the constexprs, flatten the tuples, and remove any more + # constexprs. If we remove them all at the end, we won't be able to remove + # entire tuples that are a single constexpr. If we remove them before + # flattening, we will miss mixed-tuples. So we do it twice. def _serialize_signature(sig): if isinstance(sig, tuple): return ','.join(map(_serialize_signature, sig)) return sig - - def _extracted_type(ty): - if isinstance(ty, tuple): - val = ','.join(map(_extracted_type, ty)) - return f"[{val}]" - if ty[0] == '*': - return "PyObject*" - if ty in ("constexpr", "nvTmaDesc"): - return "PyObject*" - return ty_to_cpp(ty) - - def format_of(ty): - if isinstance(ty, tuple): - val = ''.join(map(format_of, ty)) - return f"({val})" - if ty[0] == '*': - return "O" - if ty in ("constexpr", "nvTmaDesc"): - return "O" - return { - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "L", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty_to_cpp(ty)] - - args_format = ''.join([format_of(ty) for ty in signature.values()]) - format = "iiiKKpOOOOO" + args_format + + # Remember & remove all the constexpr before flattening. + constant_indices_before_flattening = {i for i, [k, v] in enumerate(signature.items()) if v == 'constexpr'} + # constant_indices_before_flattening = [2, 4] + signature = {k: v for k, v in signature.items() if v != 'constexpr'} + # signature = {'_0': 'i32', 'Ptrs': (), 'values': '[*f32, constexpr]'} + + # Flatten. signature = ','.join(map(_serialize_signature, signature.values())) + # signature = 'i32,,*f32,constexpr' signature = list(filter(bool, signature.split(','))) - signature = {i: s for i, s in enumerate(signature)} - args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. 
- arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") - internal_args_list = [] - for i, ty in signature.items(): - if ty[0] == "*": - internal_args_list.append(f"ptr_info{i}.dev_ptr") - elif ty == "nvTmaDesc": - # Note: we have to dereference the pointer - internal_args_list.append(f"*tma_ptr{i}") - elif ty != "constexpr": - internal_args_list.append(f"_arg{i}") - params = range(len(signature)) - - # generate glue code - newline = '\n ' - ptr_decls = [ - f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" - for i, ty in signature.items() - if ty[0] == "*" - ] - tma_decls = [ - f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items() - if ty == "nvTmaDesc" - ] - params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] - params.append("&global_scratch") - src = f""" -#include \"cuda.h\" -#include -#include -#include - -static inline void gpuAssert(CUresult code, const char *file, int line) -{{ - if (code != CUDA_SUCCESS) - {{ - const char* prefix = "Triton Error [CUDA]: "; - const char* str; - cuGetErrorString(code, &str); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - PyGILState_STATE gil_state; - gil_state = PyGILState_Ensure(); - PyErr_SetString(PyExc_RuntimeError, err); - PyGILState_Release(gil_state); - }} -}} - -#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - -typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra); - -static cuLaunchKernelEx_t getLaunchKernelExHandle() {{ - // Open the shared library - void* handle = dlopen("libcuda.so.1", RTLD_LAZY); - if (!handle) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); - return NULL; - }} - // Clear any existing error - dlerror(); - cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx"); - // Check for errors - const char *dlsym_error = dlerror(); - if (dlsym_error) {{ - PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1"); - return NULL; - }} - return cuLaunchKernelExHandle; -}} - -static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - void *params[] = {{ {', '.join(params)} }}; - if (gridX*gridY*gridZ > 0) {{ - if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{ - CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); - }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{ - CUlaunchAttribute launchAttr[1]; - CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; - launchAttr[0] = coopAttr; - - CUlaunchConfig config; - config.gridDimX = gridX; - config.gridDimY = gridY; - config.gridDimZ = gridZ; - config.blockDimX = 32 * num_warps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared_memory; - config.hStream = stream; - config.attrs = launchAttr; - config.numAttrs = 1; - - static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; - if (cuLaunchKernelExHandle == NULL) {{ - cuLaunchKernelExHandle = getLaunchKernelExHandle(); - }} - CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); 
- - }} else {{ - CUlaunchAttribute launchAttr[3]; - launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launchAttr[0].value.clusterDim.x = clusterDimX; - launchAttr[0].value.clusterDim.y = clusterDimY; - launchAttr[0].value.clusterDim.z = clusterDimZ; - launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; - launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; - - unsigned numAttrs = 2; - if (0 != launch_cooperative_grid) {{ - CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}}; - launchAttr[2] = coopAttr; - numAttrs = 3; - }} - - CUlaunchConfig config; - config.gridDimX = gridX * clusterDimX; - config.gridDimY = gridY * clusterDimY; - config.gridDimZ = gridZ * clusterDimZ; - config.blockDimX = 32 * num_warps; - config.blockDimY = 1; - config.blockDimZ = 1; - config.sharedMemBytes = shared_memory; - config.hStream = stream; - config.attrs = launchAttr; - config.numAttrs = numAttrs; - static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL; - if (cuLaunchKernelExHandle == NULL) {{ - cuLaunchKernelExHandle = getLaunchKernelExHandle(); - }} - CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0)); - }} - }} -}} - -typedef struct _DevicePtrInfo {{ - CUdeviceptr dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); - if(!ptr_info.dev_ptr) - return ptr_info; - uint64_t dev_ptr; - int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); - if (status == CUDA_ERROR_INVALID_VALUE) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); - ptr_info.valid = false; - }} else if (status != CUDA_SUCCESS) {{ - CUDA_CHECK(status); // Catch any other cuda API errors - ptr_info.valid = false; - }} - ptr_info.dev_ptr = dev_ptr; - Py_DECREF(ret); // Thanks ChatGPT! 
- return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; -}} - -static inline CUtensorMap* getTmaDesc(PyObject *obj) {{ - if (sizeof(CUtensorMap*) != 8) {{ - PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation"); - return NULL; - }} - - PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr"); - if (!method_handle) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist"); - return NULL; - }} - - PyObject *empty_tuple = PyTuple_New(0); - if (!empty_tuple) {{ - Py_DECREF(method_handle); - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(method_handle); - if (!method_ret) {{ - PyErr_SetString(PyExc_SystemError, "Internal Python error!"); - return NULL; - }} - - if (!PyLong_Check(method_ret)) {{ - PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int"); - Py_DECREF(method_ret); - return NULL; - }} - - uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret); - Py_DECREF(method_ret); - if (!ptr_as_uint) {{ - PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()"); - return NULL; - }} - if (ptr_as_uint % 64 != 0) {{ - PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned"); - return NULL; - }} - - return (CUtensorMap*)(ptr_as_uint); -}} - -static void ensureCudaContext() {{ - CUcontext pctx; - CUDA_CHECK(cuCtxGetCurrent(&pctx)); - if (!pctx) {{ - // Ensure device context. - CUdevice device; - CUDA_CHECK(cuDeviceGet(&device, 0)); - CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); - CUDA_CHECK(cuCtxSetCurrent(pctx)); - }} -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - // ensure cuda context is valid before calling any CUDA APIs, e.g. 
before getPointer calls cuPointerGetAttributes - ensureCudaContext(); - - int gridX, gridY, gridZ; - uint64_t _stream; - uint64_t _function; - int launch_cooperative_grid; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - PyObject *global_scratch_obj = NULL; - {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, - &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook{args_list})) {{ - return NULL; - }} - - int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ; - if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple"); - return NULL; - }} - - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - CUdeviceptr global_scratch = 0; - if (global_scratch_obj != Py_None) {{ - DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1); - if (!global_scratch_info.valid) {{ - return NULL; - }} - global_scratch = global_scratch_info.dev_ptr; - }} - - // raise exception asap - {newline.join(ptr_decls)} - {newline.join(tma_decls)} - Py_BEGIN_ALLOW_THREADS; - _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''}); - Py_END_ALLOW_THREADS; - if (PyErr_Occurred()) {{ - return NULL; - }} - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - - }} - - Py_RETURN_NONE; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" - return src + # signature = ['i32', '*f32', 'constexpr'] + + # Remove any constexprs after flattening. 
+ constant_indices_after_flattening = {i for i, s in enumerate(signature) if s == 'constexpr'} + # constant_indices_after_flattening = [2] + signature = {i: s for i, s in enumerate(signature) if s != 'constexpr'} + # signature = {0: 'i32', 1: '*f32'} + + signature_metadata = cuda_utils.build_signature_metadata( + ty for ty in signature.values()) + + def wrapper(grid_dim_x: int, grid_dim_y: int, grid_dim_z: int, + stream: int, kernel: int, global_scratch: any, + packed_metadata: tuple[int, int, int, int, int, int], + hook_args: any, + launch_enter_hook: Callable[..., None], + launch_exit_hook: Callable[..., None], + *args: any) -> None: + # Given the example above, args would look something like: + # args = [8, (), 5, (3, 4), (2, 2, 2)] + # constant_indices_before_flattening = [2, 4] + # Remove constantexprs before flattening: + non_const_args = [arg + for idx, arg in enumerate(args) + if idx not in constant_indices_before_flattening + ] + # non_const_args = [8, (), (3, 4)] + non_const_args = flatten_tuples(non_const_args) + # non_const_args = [8, 3, 4] + # constant_indices_after_flattening = [2] + non_const_args = [arg + for idx, arg in enumerate(non_const_args) + if idx not in constant_indices_after_flattening + ] + # non_const_args = [8, 3] + cuda_utils.launch(grid_dim_x, grid_dim_y, grid_dim_z, stream, kernel, + packed_metadata, hook_args, launch_enter_hook, + launch_exit_hook, signature_metadata, global_scratch, + non_const_args) + return wrapper class CudaLauncher(object): @@ -506,9 +176,7 @@ def __init__(self, src, metadata): arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x constants = {arg_idx(idx): value for idx, value in constants.items()} signature = {idx: value for idx, value in src.signature.items()} - src = make_launcher(constants, signature) - mod = compile_module_from_src(src, "__triton_launcher") - self.launch = mod.launch + self.launch = make_launcher(constants, signature) self.global_scratch_size = metadata.global_scratch_size self.global_scratch_align = metadata.global_scratch_align self.launch_cooperative_grid = metadata.launch_cooperative_grid @@ -520,7 +188,7 @@ def __call__(self, gridX, gridY, gridZ, stream, function, *args): global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream) else: global_scratch = None - self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args) + self.launch(gridX, gridY, gridZ, stream, function, global_scratch, *args) class CudaDriver(GPUDriver): @@ -551,7 +219,7 @@ def is_active(): import torch return torch.cuda.is_available() and (torch.version.hip is None) except ImportError: - return False + return True def get_benchmarker(self): from triton.testing import do_bench diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 458913dba595..e31317b05d00 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -113,6 +113,15 @@ def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? 
attr-dict `:` functional-type(operands, $res)"; } +def NVGPU_SparseWGMMAOp : NVGPU_Op<"wgmma_sp", []> { + let arguments = (ins WGMMA_OperandType:$opA, I32:$metaA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC, + I32Attr:$m, I32Attr:$n, I32Attr:$k, + WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, + WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); + let results = (outs LLVM_AnyStruct:$res); + let assemblyFormat = "$opA `meta` $metaA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)"; +} + def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { let arguments = (ins BoolAttr:$bCluster); let assemblyFormat = "attr-dict"; diff --git a/third_party/nvidia/language/cuda/BUILD b/third_party/nvidia/language/cuda/BUILD new file mode 100644 index 000000000000..55e6ec8795c1 --- /dev/null +++ b/third_party/nvidia/language/cuda/BUILD @@ -0,0 +1,13 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index ae042d0dad0a..164854a70913 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -442,10 +442,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); - uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + std::vector outputConstraints; + outputConstraints.reserve(outputStructType.getBody().size()); + for (mlir::Type type : outputStructType.getBody()) { + if (type.isF32()) { + outputConstraints.push_back("=f"); + continue; + } else if (type.isF64()) { + outputConstraints.push_back("=d"); + continue; + } + unsigned bitwidth = isa(type) ? 
+ 64 : type.getIntOrFloatBitWidth(); + switch (bitwidth) { + case 1: + outputConstraints.push_back("=b"); + break; + case 16: + outputConstraints.push_back("=h"); + break; + case 32: + outputConstraints.push_back("=r"); + break; + case 64: + outputConstraints.push_back("=l"); + break; + default: + assert(false && "unsupported bitwidth"); + } + } + return outputConstraints; } OperandsAndConstraints diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp index 449c8f50de15..2ac8604a5c84 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp @@ -3,6 +3,7 @@ #include "PatternTritonGPUOpToLLVM.h" #include "Utility.h" #include "mlir/Support/LLVM.h" +#include "third_party/triton/include/triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" using namespace mlir; @@ -58,8 +59,8 @@ enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 }; inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB, Type scaleAType, Type scaleBType) { if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) { - if (llvm::isa(scaleAType) && - llvm::isa(scaleBType)) { + if (llvm::isa(scaleAType) && + llvm::isa(scaleBType)) { return mxfpKind::mxf4nvf4; } return mxfpKind::mxf4; @@ -100,10 +101,11 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter, return 1; if (type.isF32()) return 2; - if (llvm::isa(type)) + if (llvm::isa(type)) return 0; - if (llvm::isa(type)) + if (llvm::isa(type)) return 1; + llvm_unreachable("Unsupported type."); }; static_assert(sizeof(TCGen5InstructionDescriptor) == 4, @@ -225,7 +227,8 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc, opcode += "f16"; else if (srcElementTy.isF32()) opcode += "tf32"; - else if (llvm::isa(srcElementTy)) + else if (llvm::isa(srcElementTy) || + llvm::isa(srcElementTy)) opcode += "f8f6f4"; else assert(0 && "Unsupported type."); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index d1f613db7114..c9f09f6b5827 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -79,7 +79,7 @@ int64_t getSwizzlingFromLayout(const NVMMASharedEncodingAttr &layout, return swizzlingByteWidth; } -static Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, +Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, int64_t swizzling, uint32_t stride) { auto b = TritonLLVMOpBuilder(loc, rewriter); static_assert(sizeof(SMEMDescriptor) == 8, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 4f1c36236a6e..69fe3530e561 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -286,7 +286,8 @@ static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, const SmallVector &v) -> SmallVector { auto b = TritonLLVMOpBuilder(loc, rewriter); int numElements = v.size(); - assert(numElements == 4 || numElements == 2 && "invalid vector size"); + assert(numElements == 8 || numElements == 4 || + numElements == 2 && 
"invalid vector size"); auto ctx = rewriter.getContext(); int inBitwidth = inType.getIntOrFloatBitWidth(); @@ -583,6 +584,114 @@ struct SIToFPOpConversion : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), computeCapability(computeCapability) {} + LogicalResult matchAndRewrite( + arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (succeeded(matchAndRewriteInt4ToBf16Conversion(op, rewriter))) { + return success(); + } + return Base::matchAndRewrite(op, adaptor, rewriter); + } + + // Matches subgraph of convert 8xi4 to 8xbf16 and rewrites it to inline PTX. + LogicalResult matchAndRewriteInt4ToBf16Conversion( + arith::SIToFPOp op, ConversionPatternRewriter &rewriter) const { + if (computeCapability < 90) return failure(); + Type inElemTy = getElementType(op.getIn()); + Type outElemTy = getElementType(op.getOut()); + if (!inElemTy.isInteger(8) || !outElemTy.isBF16()) return failure(); + FailureOr unpack = matchInt4Unpack(op.getIn()); + if (failed(unpack)) return failure(); + + Location loc = op.getLoc(); + Value src = rewriter.getRemappedValue(unpack.value()); + auto structTy = dyn_cast(src.getType()); + if (!structTy || structTy.getBody().size() % 4 != 0) return failure(); + auto isInt8 = [](Type type) { return type.isInteger(8); }; + if (!all_of(structTy.getBody(), isInt8)) return failure(); + + const LLVMTypeConverter *typeConverter = getTypeConverter(); + assert(inElemTy == typeConverter->convertType(inElemTy)); + assert(outElemTy == typeConverter->convertType(outElemTy)); + + const std::string S4_to_Bf16_sm90 = R"({ + .reg .b32 r<4>, mi, mf; + mov.b32 mi, 0x43404340 - 0x00080008; + mov.b32 mf, 0x43404340; + // Shift 4-bit inputs to 16-bit boundary. + shr.u32 r1, $4, 4; + shr.u32 r2, $4, 8; + shr.u32 r3, $4, 12; + // Sign-extend from 4 bits is equivalent to (x ^ 0x8) - 0x8. + lop3.b32 r0, $4, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r1, r1, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r2, r2, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + lop3.b32 r3, r3, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; + // Interger-add magic number (minus bias from sign-extend above). + add.s16x2 r0, r0, mi; + add.s16x2 r1, r1, mi; + add.s16x2 r2, r2, mi; + add.s16x2 r3, r3, mi; + // Float-subtract magic number. + sub.bf16x2 r0, r0, mf; + sub.bf16x2 r1, r1, mf; + sub.bf16x2 r2, r2, mf; + sub.bf16x2 r3, r3, mf; + // Shuffle results into correct order. + prmt.b32 $0, r1, r0, 0x5410; + prmt.b32 $1, r3, r2, 0x5410; + prmt.b32 $2, r1, r0, 0x7632; + prmt.b32 $3, r3, r2, 0x7632; + })"; + + SmallVector resultVals; + SmallVector unpackedVals = unpackLLElements(loc, src, rewriter); + auto cvtFunc = makeConverterFromPtx(S4_to_Bf16_sm90, inElemTy, outElemTy); + for (ValueRange operands = unpackedVals; !operands.empty(); + operands = operands.drop_front(4)) { + SmallVector inVals = { + operands[0], operands[1], operands[2], operands[3], + // Repeat operands so that cvtFunc produces 8 outputs. + operands[0], operands[1], operands[2], operands[3]}; + auto outVals = cvtFunc(loc, rewriter, inVals); + assert(inVals.size() == outVals.size()); + resultVals.append(outVals.begin(), outVals.end()); + } + + resultVals = maybeDeduplicate(op, resultVals); + Value view = + packLLElements(loc, typeConverter, resultVals, rewriter, op.getType()); + rewriter.replaceOp(op, view); + + return success(); + } + + // Returns the source if value is the result of an 2xi4 -> 2xi8 unpack + // sequence. 
+ static FailureOr matchInt4Unpack(Value value) { + auto reshape = value.getDefiningOp(); + if (!reshape) return failure(); + auto join = reshape.getSrc().getDefiningOp(); + if (!join) return failure(); + auto shrHi = join.getLhs().getDefiningOp(); + if (!shrHi || !isConst4(shrHi.getRhs())) return failure(); + auto shrLo = join.getRhs().getDefiningOp(); + if (!shrLo || !isConst4(shrLo.getRhs())) return failure(); + auto shlLo = shrLo.getLhs().getDefiningOp(); + if (!shlLo || !isConst4(shlLo.getRhs())) return failure(); + if (shrHi.getLhs() != shlLo.getLhs()) return failure(); + return shrHi.getLhs(); + } + + // Returns true if the value is equal to 4. + static bool isConst4(Value v) { + auto constOp = v.getDefiningOp(); + if (!constOp) return false; + auto attr = mlir::dyn_cast(constOp.getValue()); + if (!attr || !attr.isSplat()) return false; + return attr.getSplatValue().getLimitedValue() == 4; + }; + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, diff --git a/third_party/proton/BUILD b/third_party/proton/BUILD new file mode 100644 index 000000000000..783718497934 --- /dev/null +++ b/third_party/proton/BUILD @@ -0,0 +1,130 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +td_library( + name = "td_files", + srcs = glob(["dialect/include/Dialect/Proton/IR/*.td"]), + includes = ["dialect/include"], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "//:td_files", + ], +) + +gentbl_cc_library( + name = "proton_ir_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "proton_ir_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "dialect/include/Dialect/Proton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "dialect/include/Dialect/Proton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "proton_ir_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "dialect/include/Dialect/Proton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "dialect/include/Dialect/Proton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "dialect/include/Dialect/Proton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "dialect/include/Dialect/Proton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "dialect/include/Dialect/Proton/IR/ProtonOps.td", + deps = ["td_files"], +) + +cc_library( + name = "ProtonIRDialect", + srcs = glob([ + "dialect/lib/Dialect/Proton/IR/*.cpp", + ]), + hdrs = glob([ + "dialect/include/Dialect/Proton/IR/*.h", + ]), + includes = [ + "..", # because proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc + "dialect/include", + ], + deps = [ + ":proton_ir_attr_inc_gen", + 
":proton_ir_dialect_inc_gen", + ":proton_ir_ops_inc_gen", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonProtonToLLVM", + srcs = glob([ + "dialect/lib/TritonProtonToLLVM/*.cpp", + ]), + hdrs = glob([ + "dialect/include/TritonProtonToLLVM/*.h", + ]), + includes = [ + ], + deps = [ + ":ProtonIRDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include index fe4f4a1aa9bd..4400934bdf78 120000 --- a/third_party/proton/proton/_C/include +++ b/third_party/proton/proton/_C/include @@ -1 +1 @@ -../../csrc/include/ \ No newline at end of file +../../csrc/include \ No newline at end of file diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 000000000000..4cbadcfa4655 --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,144 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:non_prod"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTestCatchAll", + srcs = glob( + [ + "Dialect/**/*.cpp", + ], + exclude = [ + "Dialect/TritonGPU/DialectTest.cpp", + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = [ + "Dialect/TritonGPU/DialectTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "LinearLayoutConversionsTest", + srcs = [ + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "SwizzleTest", + srcs = [ + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + 
"Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ], +)