diff --git a/BUILD b/BUILD new file mode 100644 index 0000000000000..4c0d8acb0f7c7 --- /dev/null +++ b/BUILD @@ -0,0 +1,908 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. +# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:gce"], + # default_visibility = [ + # "//third_party/py/jax:__subpackages__", + # "//third_party/tensorflow/compiler/xla:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = 
"include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. + strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + 
["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/triton/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = 
"triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonAnalysis", + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/AxisInfo.cpp", + "lib/Analysis/Membar.cpp", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "lib/Analysis/Utility.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/AxisInfo.h", + "include/triton/Analysis/Membar.h", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + "lib/Conversion/TritonGPUToLLVM/Utility.h", + ], + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + ]) + [ + "lib/Analysis/Utility.cpp", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/NVGPU/IR/*.h", + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + ]) + [ + "include/triton/Analysis/Utility.h", # Avoid circular dependency. + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. 
+ ], + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvgpu_attr_inc_gen", + ":triton_nvgpu_dialect_inc_gen", + ":triton_nvgpu_ops_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). 
+) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]) + [ + "lib/Conversion/TritonGPUToLLVM/TypeConverter.h", + ], + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + includes = ["include"], + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + 
"@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + # copybara:uncomment "//third_party/py/triton/google:find_cuda", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + includes = ["include"], + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + includes = ["include"], + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], + includes = ["include"], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/RegisterTritonDialects.h", + "bin/triton-opt.cpp", + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//third_party/triton/test:TritonTestAnalysis", + "//third_party/triton/third_party/nvidia:NVGPUToLLVM", + 
"//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + includes = [ + ".", # because it includes "lib/Target/LLVMIR/LLVMPasses.h" + "include", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = [ + "bin/RegisterTritonDialects.h", + "bin/triton-reduce.cpp", + ], + includes = [ + "include", + ], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + "@llvm-project//mlir:NVVMDialect", + "//third_party/triton/test:TritonTestAnalysis", + "//third_party/triton/third_party/nvidia:NVGPUToLLVM", + "//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 3650bdae51112..f82e3fbbfdff4 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -483,7 +483,8 @@ bool supportMMA(triton::DotOp op, int version) { auto aElemTy = op.getA().getType().cast().getElementType(); auto bElemTy = op.getB().getType().cast().getElementType(); if (version == 3) { - if (triton::tools::getBoolEnv("DISABLE_MMA_V3")) + // TODO(b/311157761): enable mma_v3 + if (!triton::tools::getBoolEnv("ENABLE_MMA_V3")) return false; auto retType = op.getType(); auto retShapePerCTA = getShapePerCTA(retType); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 67665b7b14837..c6752f2401ce2 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1433,6 +1433,7 @@ MfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef shape, return {32, parentShapePerCTA[1]}; } else { assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); + return {}; } } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 42e7258308b2f..6d7073f7386a3 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -24,7 +24,7 @@ using ttg::SliceEncodingAttr; // supported static int getMMAVersionSafe(int computeCapability, tt::DotOp op) { int baseVersion = 0; - if (computeCapability < 75) { + if (computeCapability < 80) { baseVersion = 1; } else if (computeCapability < 90) { baseVersion = 2; @@ -305,8 +305,10 @@ class BlockedToMMA : public mlir::RewritePattern { } else { // convert operands - int minBitwidth = - std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + // TODO(b/296812125): Fix minBitwidth issue upstream and uncomment. 
+ // int minBitwidth = + // std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)); + int minBitwidth = 0; Type minType = IntegerType::get(ctx, minBitwidth); // convert A operand auto newAEncoding = ttg::DotOperandEncodingAttr::get( diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 324d80abbe052..c16384951528c 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -7,7 +7,19 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include +#include +#include #include +#include + +inline bool isPipeliningEnabled() { + const char *s = std::getenv("ENABLE_PIPELINING"); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return (str == "on" || str == "true" || str == "1"); +} namespace { @@ -335,7 +347,9 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); - if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) + // TODO(b/291216607): Fix crashes and enable by default. + if (isPipeliningEnabled() && + triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) patterns.add(context); patterns.add(context); patterns.add(context); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index 06eeba79f83fa..30c5dc4ed0d29 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -40,7 +40,8 @@ struct FenceInsertionPass // Only insert fences for compute capability 9.0 if (computeCapability < 90) return; - if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) + // TODO(b/311157761): enable mma_v3 + if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 0000000000000..5245616141f97 --- /dev/null +++ b/python/BUILD @@ -0,0 +1,90 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. 
+ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "//third_party/triton/python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["//third_party/triton/third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "//third_party/triton/third_party/nvidia:triton_nvidia", + ], +) + +pybind_extension( + name = "triton_launcher", + srcs = [ + "triton/compiler/triton_launcher.c", + ], + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/src/ir.cc b/python/src/ir.cc index c5be3cebc2212..f07ce4ccf93d3 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1423,26 +1423,36 @@ void init_triton_ir(py::module &&m) { .def("enable_debug", [](PassManager &self) { auto *context = self.getContext(); - context->printOpOnDiagnostic(true); - context->printStackTraceOnDiagnostic(true); - context->disableMultithreading(); - context->getDiagEngine().registerHandler([](Diagnostic &diag) { - llvm::outs() << diag << "\n"; - return success(); - }); - - if (!triton::tools::getBoolEnv("MLIR_ENABLE_DUMP")) - return; - auto printingFlags = OpPrintingFlags(); - printingFlags.elideLargeElementsAttrs(16); - printingFlags.enableDebugInfo(); - auto print_always = [](Pass *, Operation *) { return true; }; - self.enableIRPrinting( - /*shouldPrintBeforePass=*/print_always, - /*shouldPrintAfterPass=*/print_always, - /*printModuleScope=*/true, - /*printAfterOnlyOnChange=*/false, - /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), printingFlags); + bool have_diagnostics = + triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS"); + bool have_dump = triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"); + if (have_diagnostics || have_dump) { + context->disableMultithreading(); + } + if (have_diagnostics) { + context->printOpOnDiagnostic(true); + context->printStackTraceOnDiagnostic(true); + context->getDiagEngine().registerHandler( + [](Diagnostic &diag) { + 
llvm::outs() << diag << "\n"; + return success(); + }); + } + if (have_dump) { + auto printingFlags = OpPrintingFlags(); + printingFlags.elideLargeElementsAttrs(16); + printingFlags.enableDebugInfo(); + auto print_always = [](Pass *, Operation *) { + return true; + }; + self.enableIRPrinting( + /*shouldPrintBeforePass=*/print_always, + /*shouldPrintAfterPass=*/print_always, + /*printModuleScope=*/true, + /*printAfterOnlyOnChange=*/false, + /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), + printingFlags); + } }) .def("run", [](PassManager &self, ModuleOp &mod) { // TODO: maybe dump module to file and print error for better diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 0000000000000..b6a3534474d1d --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,27 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["test_*.py"], + + #TODO(b/321005767): fix failing test + exclude = [ + "test_performance.py", + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/hopper/BUILD b/python/test/unit/hopper/BUILD new file mode 100644 index 0000000000000..66623e09476fd --- /dev/null +++ b/python/test/unit/hopper/BUILD @@ -0,0 +1,21 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["**/test_*.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/BUILD b/python/test/unit/language/BUILD new file mode 100644 index 0000000000000..3db8ea728cc11 --- /dev/null +++ b/python/test/unit/language/BUILD @@ -0,0 +1,30 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + srcs = [ + "conftest.py", + "test_core.py", + ], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["**/test_*.py"], + exclude = [ + "test_subprocess.py", # TODO(b/320224484): fix failing test + "test_reproducer.py", # this is not an actual test, but a tool for running reproducers + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/operators/BUILD b/python/test/unit/operators/BUILD new file mode 100644 index 0000000000000..34dbc4f30864d --- /dev/null +++ b/python/test/unit/operators/BUILD @@ -0,0 +1,24 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + [ + "**/test_*.py", + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/runtime/BUILD b/python/test/unit/runtime/BUILD new file mode 100644 index 0000000000000..4d4884d0c67c8 --- /dev/null +++ b/python/test/unit/runtime/BUILD @@ -0,0 +1,24 @@ 
+load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = + glob( + include = ["**/test_*.py"], + exclude = [ + "test_launch.py", #TODO(b/320226169): fix failing tests + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/tools/BUILD b/python/test/unit/tools/BUILD new file mode 100644 index 0000000000000..5587dd2d10105 --- /dev/null +++ b/python/test/unit/tools/BUILD @@ -0,0 +1,26 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = + glob( + include = ["**/test_*.py"], + exclude = [ + "test_aot.py", # TODO(b/320224484): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1b..8a5dba6c4b560 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file +../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index fbf65d9e908fd..5d8fb01b1191d 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/python/triton/language/math.py b/python/triton/language/math.py index 512e7ef08c29c..ec66cc72d2f25 100644 --- a/python/triton/language/math.py +++ b/python/triton/language/math.py @@ -27,7 +27,7 @@ def byte_perm(arg0, arg1, arg2, _builder=None): @core.extern -def min(arg0, arg1, propagate_nan: core.constexpr = core.PropagateNan.NONE, _builder=None): +def min(arg0, arg1, propagate_nan: core.constexpr = core.constexpr(core.PropagateNan.NONE), _builder=None): arg0 = core._to_tensor(arg0, _builder) arg1 = core._to_tensor(arg1, _builder) arg0 = core._promote_bfloat16_to_float32(arg0, _builder=_builder) @@ -50,7 +50,7 @@ def min(arg0, arg1, propagate_nan: core.constexpr = core.PropagateNan.NONE, _bui @core.extern -def max(arg0, arg1, propagate_nan: core.constexpr = core.PropagateNan.NONE, _builder=None): +def max(arg0, arg1, propagate_nan: core.constexpr = core.constexpr(core.PropagateNan.NONE), _builder=None): arg0 = core._to_tensor(arg0, _builder) arg1 = core._to_tensor(arg1, _builder) arg0 = core._promote_bfloat16_to_float32(arg0, _builder=_builder) diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 0000000000000..ca39a542eccd7 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,63 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:gce"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# 
"@llvm-project//llvm:FileCheck", +# "//:triton-llvm-opt", +# "//:triton-opt", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# # These require adjusted RUN commands for python internally. +# "Target/tritongpu_to_llvmir_noinline.mlir", +# "Target/tritongpu_to_llvmir.mlir", +# "Target/tritongpu_to_ptx.mlir", +# # TODO(b/283035396): broken by cl536931041.patch +# "TritonGPU/dot-operands.mlir", +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonAnalysis", +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + ], +) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index aa21bf6d98fe4..df5681ce423f9 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --decompose-unsupported-conversions --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --decompose-unsupported-conversions --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> #shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 98adc1aaf617a..a229e54612f1f 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -1,5 +1,5 @@ -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s -// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=80 | FileCheck %s --check-prefix=CHECK-80 +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s +// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=80 | FILECHECK_OPTS= FileCheck %s --check-prefix=CHECK-80 // CHECK: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> // CHECK: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir index ce453b1880511..a77456c36e671 100644 --- a/test/TritonGPU/fence-inserstion.mlir +++ b/test/TritonGPU/fence-inserstion.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s +// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file 
--triton-nvidia-gpu-fence-insertion | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 0000000000000..e9ad78225ca3e --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,152 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + # copybara:uncomment_begin + # visibility = [ + # "//third_party/triton/python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "//third_party/triton/python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + compatible_with = ["//buildenv/target:gce"], + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + ":TritonNVIDIAGPUToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + compatible_with = ["//buildenv/target:gce"], + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + ], + deps = [ + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:triton_gpu_attr_inc_gen", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + compatible_with = ["//buildenv/target:gce"], + tbl_outs 
= [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + compatible_with = ["//buildenv/target:gce"], + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 0000000000000..9bd0230c572e8 --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,28 @@ +load("//third_party/bazel_rules/rules_python/python:py_extension.bzl", "py_extension") + +package( + default_applicable_licenses = ["//:license"], +) + +py_extension( + name = "cuda_utils", + srcs = ["driver.c"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "@local_config_cuda//cuda:cuda_headers", + "//third_party/python_runtime:headers", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["*.py"], + ), + visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index a201febcd52f7..4973de767c075 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -560,6 +560,18 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), axisAnalysisPass(axisAnalysisPass) {} + // True if elements allocated to a thread are contiguous within the axis. This + // is not the case in MMA-like encodings wherea thread might have elements + // (0,0),(0,1) and (8,0),(8,1) for example. The problem with this is that the + // deduplication mechanism assumes that for example constancy=4 and + // elements/thread=4 that if a thread has all elements constant. + bool contiguouslyMapped(Attribute encoding) const { + if (auto slice = encoding.dyn_cast()) { + return contiguouslyMapped(slice.getParent()); + } + return encoding.isa(); + } + // Try to deduplicate the resultVals based on the // constancy properties of the result discovered by // the axis analysis pass. 
If possible, redundant @@ -585,8 +597,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { if (!encoding) // encoding not available return resultVals; - if (!encoding.dyn_cast() && - !encoding.dyn_cast()) { + if (!contiguouslyMapped(encoding)) { // TODO: constraining the ecndoing type here is necessary for avoiding // crashes in the getElemsPerThread call below happening in the // test_core::test_fp8_dot_acc diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 2045ba1780adb..4854753069b04 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,4 +1,4 @@ -#include "NVGPUToLLVM/Passes.h" +#include "NVGPUToLLVM/Passes.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 0000000000000..33c8c504d1e75 --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,102 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:gce"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = glob([ + "Dialect/**/*.cpp", + ]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//mlir:AsmParser", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_test( + name = "EmitIndicesTest", + srcs = [ + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-private-header", + ], + }), + includes = [ + "Conversion/TritonGPUToLLVM", + ], + # We want this to be buildable to update LLVM, but it doesn't pass and never has, even in OSS: + # https://github.com/openai/triton/blob/ded624282e67e5f58db332380e6ff088f276d534/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp#L677 + tags = [ + "manual", + "notap", + ], + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "//third_party/triton/third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ":EmitIndicesTest", + ], +)