OpenXLA-specific changes

Moerafaat committed Feb 21, 2024
1 parent b1334eb commit d2e4e86
Showing 26 changed files with 1,576 additions and 38 deletions.
908 changes: 908 additions & 0 deletions BUILD

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -483,7 +483,8 @@ bool supportMMA(triton::DotOp op, int version) {
   auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
   auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
   if (version == 3) {
-    if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
+    // TODO(b/311157761): enable mma_v3
+    if (!triton::tools::getBoolEnv("ENABLE_MMA_V3"))
       return false;
     auto retType = op.getType();
     auto retShapePerCTA = getShapePerCTA(retType);
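This flips the MMAv3 gate from opt-out (DISABLE_MMA_V3) to opt-in (ENABLE_MMA_V3), so the v3 path now stays off unless the environment explicitly requests it. For illustration only — the real reader is triton::tools::getBoolEnv, and its exact accepted spellings may differ — a minimal sketch of this kind of flag reader:

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <string>

// Hypothetical stand-in for triton::tools::getBoolEnv: case-insensitively
// treats "on", "true", and "1" as enabled, anything else as disabled.
static bool getBoolEnvSketch(const std::string &name) {
  const char *raw = std::getenv(name.c_str());
  std::string value(raw ? raw : "");
  std::transform(value.begin(), value.end(), value.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  return value == "on" || value == "true" || value == "1";
}

Under this scheme, running with ENABLE_MMA_V3=1 re-enables the MMAv3 path, subject to the remaining type and shape checks below.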
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1433,6 +1433,7 @@ MfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
     return {32, parentShapePerCTA[1]};
   } else {
     assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
+    return {};
   }
 }
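The added return {} is not cosmetic: assert compiles to a no-op when NDEBUG is defined, so in release builds the else branch would fall off the end of a value-returning function — undefined behavior in C++ and a -Wreturn-type warning in GCC/Clang. A minimal illustration with a hypothetical function of the same shape (the shape values here are made up):

#include <cassert>
#include <vector>

// In a release build (-DNDEBUG) the assert vanishes; without the trailing
// return, calls with opIdx > 1 would fall off the end of the function.
std::vector<long> shapePerCTATileSketch(int opIdx) {
  if (opIdx == 0) return {32, 8};
  if (opIdx == 1) return {8, 32};
  assert(0 && "opIdx must be 0 or 1");
  return {}; // keeps the function well-defined when asserts are disabled
}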
8 changes: 5 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -24,7 +24,7 @@ using ttg::SliceEncodingAttr;
 // supported
 static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
   int baseVersion = 0;
-  if (computeCapability < 75) {
+  if (computeCapability < 80) {
     baseVersion = 1;
   } else if (computeCapability < 90) {
     baseVersion = 2;
@@ -305,8 +305,10 @@ class BlockedToMMA : public mlir::RewritePattern {
   } else {

     // convert operands
-    int minBitwidth =
-        std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
+    // TODO(b/296812125): Fix minBitwidth issue upstream and uncomment.
+    // int minBitwidth =
+    //     std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
+    int minBitwidth = 0;
     Type minType = IntegerType::get(ctx, minBitwidth);
     // convert A operand
     auto newAEncoding = ttg::DotOperandEncodingAttr::get(
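Two separate changes here. First, the version-1 cutoff moves from sm_75 to sm_80, so Turing (compute capability 7.5), which previously selected MMA v2, now falls back to v1. Folding the visible branches together, the selection behaves roughly like this sketch (the >= 90 tail lies outside the hunk, so that branch is an assumption; the real getMMAVersionSafe also takes the tt::DotOp, presumably to validate the candidate version against the op):

// Sketch of the effective MMA version selection after this change.
static int mmaVersionSketch(int computeCapability) {
  if (computeCapability < 80)
    return 1; // pre-Ampere, now including sm_75 (Turing)
  if (computeCapability < 90)
    return 2; // Ampere / Ada
  return 3;   // assumed: Hopper and newer
}

Second, pinning minBitwidth to 0 disables the operand-narrowing heuristic until b/296812125 is fixed upstream.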
16 changes: 15 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -7,7 +7,19 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include <algorithm>
+#include <cstdlib>
+#include <cctype>
+#include <memory>
+#include <string>

+inline bool isPipeliningEnabled() {
+  const char *s = std::getenv("ENABLE_PIPELINING");
+  std::string str(s ? s : "");
+  std::transform(str.begin(), str.end(), str.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
+  return (str == "on" || str == "true" || str == "1");
+}

 namespace {

@@ -335,7 +347,9 @@ class TritonGPUOptimizeDotOperandsPass

     mlir::RewritePatternSet patterns(context);
     patterns.add<SwizzleShmemConvert>(context);
-    if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
+    // TODO(b/291216607): Fix crashes and enable by default.
+    if (isPipeliningEnabled() &&
+        triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80)
       patterns.add<HoistLayoutConversion>(context);
     patterns.add<FuseTransHopper>(context);
     patterns.add<MMAV3UseRegOperand>(context);
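HoistLayoutConversion is now double-gated: the new ENABLE_PIPELINING environment flag and the pre-existing compute-capability >= 80 check. A self-contained check of the flag's accepted spellings (the helper body is repeated from the hunk above so the snippet compiles standalone; setenv is POSIX):

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <string>

inline bool isPipeliningEnabled() {
  const char *s = std::getenv("ENABLE_PIPELINING");
  std::string str(s ? s : "");
  std::transform(str.begin(), str.end(), str.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  return (str == "on" || str == "true" || str == "1");
}

int main() {
  setenv("ENABLE_PIPELINING", "TRUE", /*overwrite=*/1);
  assert(isPipeliningEnabled()); // matching is case-insensitive
  setenv("ENABLE_PIPELINING", "0", /*overwrite=*/1);
  assert(!isPipeliningEnabled()); // only "on", "true", "1" enable it
  return 0;
}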
3 changes: 2 additions & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
@@ -40,7 +40,8 @@ struct FenceInsertionPass
     // Only insert fences for compute capability 9.0
     if (computeCapability < 90)
       return;
-    if (::triton::tools::getBoolEnv("DISABLE_MMA_V3"))
+    // TODO(b/311157761): enable mma_v3
+    if (!::triton::tools::getBoolEnv("ENABLE_MMA_V3"))
       return;
     ModuleOp mod = getOperation();
     mod.walk([&](Operation *op) {
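This is the same flag flip as in lib/Analysis/Utility.cpp above: both MMAv3 gates now key off the single opt-in ENABLE_MMA_V3 variable, so one environment setting controls both the capability check and Hopper fence insertion.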
90 changes: 90 additions & 0 deletions python/BUILD
@@ -0,0 +1,90 @@
# NOTE: Do not depend on any targets from this directory,
# but use //third_party/py/triton instead.

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
    default_applicable_licenses = ["//:license"],
    default_visibility = [
        "//third_party/py/triton:__pkg__",
        "//third_party/triton/python:__subpackages__",
    ],
)

cc_library(
    name = "passes",
    hdrs = ["src/passes.h"],
    includes = ["src"],
    visibility = ["//third_party/triton/third_party:__subpackages__"],
)

pybind_extension(
    name = "libtriton",
    srcs = [
        "src/interpreter.cc",
        "src/ir.cc",
        "src/llvm.cc",
        "src/main.cc",
        "src/passes.cc",
    ],
    copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
    deps = [
        ":passes",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:IPO",
        "@llvm-project//llvm:IRReader",
        "@llvm-project//llvm:InstCombine",
        "@llvm-project//llvm:Linker",
        "@llvm-project//llvm:MC",
        "@llvm-project//llvm:Passes",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
        "@llvm-project//mlir:BytecodeWriter",
        "@llvm-project//mlir:ControlFlowDialect",
        "@llvm-project//mlir:ConversionPasses",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:IndexDialect",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:ToLLVMIRTranslation",
        "@llvm-project//mlir:Transforms",
        "//:TritonAnalysis",
        "//:TritonDialects",
        "//:TritonGPUTransforms",
        "//:TritonHSACO",
        "//:TritonLLVMIR",
        "//:TritonNvidiaGPUTransforms",
        "//:TritonPTX",
        "//:TritonToTritonGPU",
        "//:TritonTools",
        "//:TritonTransforms",
        "//third_party/triton/third_party/nvidia:triton_nvidia",
    ],
)

pybind_extension(
    name = "triton_launcher",
    srcs = [
        "triton/compiler/triton_launcher.c",
    ],
    tags = [
        "config-cuda-only",
        "requires-gpu-sm80",
    ],
    deps = [
        "@local_config_cuda//cuda:cuda_headers",
        "@local_config_cuda//cuda:cuda_runtime",
    ],
)

filegroup(
    name = "files",
    srcs = glob(
        include = ["triton/**/*.py"],
    ),
)
50 changes: 30 additions & 20 deletions python/src/ir.cc
@@ -1423,26 +1423,36 @@ void init_triton_ir(py::module &&m) {
       .def("enable_debug",
            [](PassManager &self) {
              auto *context = self.getContext();
-             context->printOpOnDiagnostic(true);
-             context->printStackTraceOnDiagnostic(true);
-             context->disableMultithreading();
-             context->getDiagEngine().registerHandler([](Diagnostic &diag) {
-               llvm::outs() << diag << "\n";
-               return success();
-             });
-
-             if (!triton::tools::getBoolEnv("MLIR_ENABLE_DUMP"))
-               return;
-             auto printingFlags = OpPrintingFlags();
-             printingFlags.elideLargeElementsAttrs(16);
-             printingFlags.enableDebugInfo();
-             auto print_always = [](Pass *, Operation *) { return true; };
-             self.enableIRPrinting(
-                 /*shouldPrintBeforePass=*/print_always,
-                 /*shouldPrintAfterPass=*/print_always,
-                 /*printModuleScope=*/true,
-                 /*printAfterOnlyOnChange=*/false,
-                 /*printAfterOnlyOnFailure*/ true, llvm::dbgs(), printingFlags);
+             bool have_diagnostics =
+                 triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS");
+             bool have_dump = triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
+             if (have_diagnostics || have_dump) {
+               context->disableMultithreading();
+             }
+             if (have_diagnostics) {
+               context->printOpOnDiagnostic(true);
+               context->printStackTraceOnDiagnostic(true);
+               context->getDiagEngine().registerHandler(
+                   [](Diagnostic &diag) {
+                     llvm::outs() << diag << "\n";
+                     return success();
+                   });
+             }
+             if (have_dump) {
+               auto printingFlags = OpPrintingFlags();
+               printingFlags.elideLargeElementsAttrs(16);
+               printingFlags.enableDebugInfo();
+               auto print_always = [](Pass *, Operation *) {
+                 return true;
+               };
+               self.enableIRPrinting(
+                   /*shouldPrintBeforePass=*/print_always,
+                   /*shouldPrintAfterPass=*/print_always,
+                   /*printModuleScope=*/true,
+                   /*printAfterOnlyOnChange=*/false,
+                   /*printAfterOnlyOnFailure*/ true, llvm::dbgs(),
+                   printingFlags);
+             }
            })
       .def("run", [](PassManager &self, ModuleOp &mod) {
         // TODO: maybe dump module to file and print error for better
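The net effect: enable_debug no longer unconditionally disables multithreading or registers a diagnostic handler. Diagnostics printing now sits behind the new MLIR_ENABLE_DIAGNOSTICS flag, IR dumping stays behind MLIR_ENABLE_DUMP, and multithreading is turned off only when either is set. As a freestanding sketch of what the dump branch configures — the same MLIR calls as in the hunk, wrapped in a helper name of my own choosing:

#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/Debug.h"

// Mirrors the have_dump branch above: print IR before and after every pass,
// at module scope, with large attributes elided and debug info attached.
void enableDumpPrintingSketch(mlir::PassManager &pm) {
  auto flags = mlir::OpPrintingFlags();
  flags.elideLargeElementsAttrs(16);
  flags.enableDebugInfo();
  auto always = [](mlir::Pass *, mlir::Operation *) { return true; };
  // Module-scope printing requires single-threaded execution, hence the
  // disableMultithreading() call in the wrapper above.
  pm.getContext()->disableMultithreading();
  pm.enableIRPrinting(/*shouldPrintBeforePass=*/always,
                      /*shouldPrintAfterPass=*/always,
                      /*printModuleScope=*/true,
                      /*printAfterOnlyOnChange=*/false,
                      /*printAfterOnlyOnFailure=*/true, llvm::dbgs(), flags);
}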
27 changes: 27 additions & 0 deletions python/test/regression/BUILD
@@ -0,0 +1,27 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
    default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
    name = "tests",
    size = "large",
    shard_count = 10,
    tags = [
        "config-cuda-only",
        "requires-gpu-sm80",
    ],
    tests = glob(
        include = ["test_*.py"],
        # TODO(b/321005767): fix failing test
        exclude = [
            "test_performance.py",
        ],
    ),
    deps = [
        "//third_party/py/torch:pytorch",
        "//third_party/py/triton",
    ],
)
21 changes: 21 additions & 0 deletions python/test/unit/hopper/BUILD
@@ -0,0 +1,21 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
    default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
    name = "tests",
    shard_count = 10,
    tags = [
        "config-cuda-only",
        "requires-gpu-sm80",
    ],
    tests = glob(
        include = ["**/test_*.py"],
    ),
    deps = [
        "//third_party/py/torch:pytorch",
        "//third_party/py/triton",
    ],
)
30 changes: 30 additions & 0 deletions python/test/unit/language/BUILD
@@ -0,0 +1,30 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
    default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
    name = "tests",
    size = "large",
    srcs = [
        "conftest.py",
        "test_core.py",
    ],
    shard_count = 10,
    tags = [
        "config-cuda-only",
        "requires-gpu-sm80",
    ],
    tests = glob(
        include = ["**/test_*.py"],
        exclude = [
            "test_subprocess.py",  # TODO(b/320224484): fix failing test
            "test_reproducer.py",  # this is not an actual test, but a tool for running reproducers
        ],
    ),
    deps = [
        "//third_party/py/torch:pytorch",
        "//third_party/py/triton",
    ],
)
24 changes: 24 additions & 0 deletions python/test/unit/operators/BUILD
@@ -0,0 +1,24 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
    default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
    name = "tests",
    size = "large",
    shard_count = 10,
    tags = [
        "config-cuda-only",
        "requires-gpu-sm80",
    ],
    tests = glob(
        [
            "**/test_*.py",
        ],
    ),
    deps = [
        "//third_party/py/torch:pytorch",
        "//third_party/py/triton",
    ],
)
24 changes: 24 additions & 0 deletions python/test/unit/runtime/BUILD
@@ -0,0 +1,24 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
    default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
    name = "tests",
    tags = [
        "config-cuda-only",
        "requires-gpu-sm80",
    ],
    tests = glob(
        include = ["**/test_*.py"],
        exclude = [
            "test_launch.py",  # TODO(b/320226169): fix failing tests
        ],
    ),
    deps = [
        "//third_party/py/torch:pytorch",
        "//third_party/py/triton",
    ],
)
26 changes: 26 additions & 0 deletions python/test/unit/tools/BUILD
@@ -0,0 +1,26 @@
load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests")

package(
    default_applicable_licenses = ["//:license"],
)

pytest_multi_tests(
    name = "tests",
    size = "large",
    shard_count = 10,
    tags = [
        "config-cuda-only",
        "requires-gpu-sm80",
    ],
    tests = glob(
        include = ["**/test_*.py"],
        exclude = [
            "test_aot.py",  # TODO(b/320224484): fix failing test
        ],
    ),
    deps = [
        "//third_party/py/torch:pytorch",
        "//third_party/py/triton",
    ],
)
2 changes: 1 addition & 1 deletion python/triton/_C/include
