Skip to content

Commit

Permalink
OpenXLA-specific changes
Browse files Browse the repository at this point in the history
  • Loading branch information
chsigg committed Feb 24, 2025
1 parent 5aa4af9 commit f6f77a1
Show file tree
Hide file tree
Showing 50 changed files with 3,804 additions and 1,043 deletions.
928 changes: 928 additions & 0 deletions BUILD

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,17 @@ We call each individual tile "rep".
"unsigned",
"getTotalElemsPerThread",
(ins "ArrayRef<int64_t>":$shape),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape);
return toLinearEncoding($_attr, shape).getTotalElemsPerThread(shape);
}]>,
InterfaceMethod<"Return element size per thread in each dimension.",
"SmallVector<unsigned>",
"getElemsPerThread",
(ins "ArrayRef<int64_t>":$shape),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return toLinearEncoding($_self, shape).getElemsPerThread(shape);
return toLinearEncoding($_attr, shape).getElemsPerThread(shape);
}]>,
// Interface for the meta information about the multiple thread hierarchy.
InterfaceMethod<"Get the shape of the warps per CTA.",
Expand Down
12 changes: 12 additions & 0 deletions lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
addArgumentMaterialization([&](OpBuilder &builder,
RankedTensorType tensorType, ValueRange inputs,
Location loc) -> Value {
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
// remaining arguments that have been converted to a new type.
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
// 'convert-triton-to-tritongpu'.
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
inputs);
llvm_unreachable("Argument rematerialization should not happen in Triton "
"-> TritonGPU conversion");
return {};
Expand All @@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// convert origValue to newValue
addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) -> Value {
// Allows partial TTIR to TTGIR conversion by materializing a conversion for
// remaining uses of values that have been converted to a new type.
// We use this to rewrite triton_xla.sparse_dot in a separate pass after
// 'convert-triton-to-tritongpu'.
return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
inputs);
llvm_unreachable("Source rematerialization should not happen in Triton -> "
"TritonGPU Conversion");
return {};
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,17 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
}

unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
return distLayout.getTotalElemsPerThread(shape);
}
return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape);
}

SmallVector<unsigned> getElemsPerThread(Attribute layout,
ArrayRef<int64_t> shape) {
if (auto distLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
return distLayout.getElemsPerThread(shape);
}
return toLinearEncoding(layout, shape).getElemsPerThread(shape);
}

Expand Down
9 changes: 7 additions & 2 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding, so we want to keep this layout conversion.
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
convert.getSrc().getType().getEncoding()))
return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
Expand Down Expand Up @@ -221,8 +226,8 @@ struct CanonicalizeConvertFromConvert
// heuristic to accommodate fused attention.
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();

Operation *arg = op.getSrc().getDefiningOp();
Expand Down
47 changes: 39 additions & 8 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ namespace mlir {
namespace triton {
namespace gpu {

namespace {

// Get the highest version supported for the hardware and the dot.
static int getMMAVersionSafe(int computeCapability, DotOp op) {
// List supported mma version in order of preference.
Expand All @@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
return 0;
}

SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
int numWarps) {
SmallVector<unsigned>
warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto rank = shape.size();
// Early exit for batched matmul
if (rank == 3)
Expand Down Expand Up @@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
}

SmallVector<unsigned, 2>
warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
const SmallVector<unsigned, 3> &instrShape) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
mlir::getForwardSlice(dotOp->getResult(0), &slices);
// Contains a chained dot. We prefer to assign warps to one axis
// to facilitate use cases like flash attention, allowing reductions within
// the same warp.
Expand Down Expand Up @@ -181,6 +179,21 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

// LocalAllocOp lowering doesn't support going from DotOperandEncoding
// to SharedEncoding.
if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
argType.getEncoding())) {
// Create a layout conversion from DotOperandEncoding to BlockedEncoding
// then pass it to the LocalAllocOp.
auto newArgType = RankedTensorType::get(
argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
auto dotOperandToBlockedCvt =
rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
dotOperandToBlockedCvt);
}

return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

Expand All @@ -204,7 +217,7 @@ getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
}

SmallVector<unsigned, 3>
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
getWarpsPerTile(Operation* dotOp, const ArrayRef<int64_t> shape, int version,
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
switch (version) {
case 2:
Expand All @@ -218,6 +231,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
}

static bool bwdFilter(Operation *op) {
// Dot operand layout assignment to Predicates are not currently supported
// during lowering from TritonGPU to LLVM in Triton for MMA cases. This
// condition limits visibility of the original bit-width so that predicate
// are not considered, hence, kwidth can never be = 32.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
return false;
}

return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
Expand All @@ -237,7 +260,7 @@ static bool bwdFilter(Operation *op) {
// result, kwidth can be the bitwidth of the lower precision primitive.
// Conversely, in the downcasting scenario, no reordering is performed,
// making it directory use the lower precision primitive.
static int computeOrigBitWidth(Value x) {
int computeOrigBitWidth(Value x) {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
Expand All @@ -257,6 +280,9 @@ static int computeOrigBitWidth(Value x) {
}
return origBitWidth;
}
// Move anonymous namespace down, so getWarpsPerTile is visible to the sparsity
// extension.
namespace {

class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
int computeCapability;
Expand Down Expand Up @@ -1147,6 +1173,11 @@ class TritonGPUAccelerateMatmulPass
}
};

Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
int opIdx, bool allowTranspose) {
return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
}

} // namespace gpu
} // namespace triton
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,

Value zero = builder.createWithStage<arith::ConstantIntOp>(
forOp.getLoc(), stage, clusterId, 0, 32);

// Replace the load with insert/extract slice.
builder.setInsertionPoint(loadOp);
Location loc = loadOp.getLoc();
Expand Down Expand Up @@ -524,7 +525,8 @@ assignMemoryLayouts(scf::ForOp &forOp,

bool isTMALoad = isa<tt::ExperimentalDescriptorLoadOp,
tt::ExperimentalDescriptorGatherOp>(op);
loadsToPipeline.insert(&op);
// TODO: b/381421713 - Uncomment this once pipelining is fixed.
// loadsToPipeline.insert(&op);
LoadInfo loadInfo;
for (auto use : users) {
if (isa<mlir::triton::DotOpInterface>(use)) {
Expand Down Expand Up @@ -562,6 +564,11 @@ assignMemoryLayouts(scf::ForOp &forOp,
getBlockedEncoding(loadOp, axisInfoAnalysis);
}
}

// TODO: b/381421713 - Remove this once pipelining is fixed.
if (!loadInfo.sharedEncoding) continue;
loadsToPipeline.insert(&op);

loadToInfo[&op] = loadInfo;
}
// Make sure all loads in loadsToPipeline are in loadToInfo.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,12 @@ mlir::triton::maybeGetStageCluster(Operation *op) {
}
std::pair<int, int> mlir::triton::getStageCluster(Operation *op) {
auto res = maybeGetStageCluster(op);
assert(res.has_value() || "Operation is missing stage & cluster attribute");
if (!res.has_value()) { // DO NOT SUBMIT
llvm::errs() << "op without stage & cluster:\n";
op->dump();
op->getParentOfType<tt::FuncOp>().dump();
}
assert(res.has_value() && "Operation is missing stage & cluster attribute");
return *res;
}

Expand Down
26 changes: 24 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
// opIdx: 0 => a, 1 => b
auto type = cast<triton::gpu::MemDescType>(v.getType());
SmallVector<int64_t> shape{type.getShape().begin(), type.getShape().end()};
SmallVector<int64_t> offset{0, 0};
SmallVector<int64_t> offset(shape.size(), 0);
Type elementType = type.getElementType();

// k => (prefetchWidth, k - prefetchWidth)
Expand All @@ -146,8 +146,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
type.getMutableMemory(), type.getAllocShape()),
v, offsetsVal);

// We need to assign kwidth to zero in the case where the parent layout is
// Blocked, otherwise the verifier emits a failure. The parent layout is
// Blocked only when Tensor Cores are disabled.
int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
? 0
: prefetchWidth / 8;
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
builder.getContext(), opIdx, dotEncoding, kwidth);
Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
Expand Down Expand Up @@ -197,6 +203,22 @@ LogicalResult Prefetcher::initialize() {
break;
if (!op->getResult(0).hasOneUse())
break;
// Similar to issues faced in HoistLayoutConversion pattern in
// OptimizeDotOperands.cpp, we can't propagate through type casts from
// predicates as they aren't supported in Triton when encoded with dot_op
// layout.
if (isa<arith::UIToFPOp>(op)) {
Type srcType = getElementTypeOrSelf(op->getOperand(0));
if (srcType.isInteger(1))
break;
}
// Propagation through ExpandDims is currently not supported. This blindly
// replaces the encoding with dot encoding & but ExpandDims requires a
// SliceEncoding. This could be rewritten to support it somehow, but I
// don't think it's trivial & it's currently crashing.
if (isa<ExpandDimsOp>(op)) {
break;
}
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
// NYI for other encodings, for example if we have transpose
Expand Down
37 changes: 33 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class LayoutRematerialization {
SetVector<Operation *> opToDelete;
FuncOp funcOp;
DominanceInfo domInfo;
PostDominanceInfo postDomInfo;
};

void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
Expand Down Expand Up @@ -1120,12 +1121,40 @@ void LayoutRematerialization::hoistConvertDotOperand(
ConvertLayoutOp convertOp) {
auto targetType = convertOp.getType();
// The pass is targeted to Nvidia mma/wgmma dot operands

// Partial cherry-pick of https://github.com/triton-lang/triton/pull/5475.
// Path 2 in b/391692127#comment28. Added check for parent being a for loop.
auto canBePipelined = [&](ConvertLayoutOp convertOp) {
auto parent = dyn_cast<scf::ForOp>(convertOp->getParentOp());
if (!parent)
return false;

// Find all the dot-like ops in the for loop that have a nvidia dot operand
// encoding on the lhs and check if any of them post-dominates the load +
// cvt
SmallVector<Operation *> dotLikeOps;
parent->walk([&](Operation *op) {
if (!isa<mlir::triton::DotOpInterface>(op))
return;
auto opType = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
if (!opType)
return;
auto dotEnc = dyn_cast<DotOperandEncodingAttr>(opType.getEncoding());
if (!dotEnc)
return;
if (isa<NvidiaMmaEncodingAttr>(dotEnc.getParent()))
dotLikeOps.push_back(op);
});
if (dotLikeOps.empty())
return false;
return llvm::any_of(dotLikeOps, [&](Operation *dot) {
return postDomInfo.postDominates(dot, convertOp);
});
};

// We move convert #dot_operand next to their loads. This is done
// so that it's then easy to pipeline these loads
// TODO: Perhaps we should do this whenever convertOp is within a loop

auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
if (!(dotEnc && isa<NvidiaMmaEncodingAttr>(dotEnc.getParent())))
if (!canBePipelined(convertOp))
return;

// We hoist over any operation that can be done without data movement between
Expand Down
30 changes: 19 additions & 11 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,18 +1022,26 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
return std::nullopt;
auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType())
.getEncoding());
if (!dotOpEnc)
auto enc =
cast<triton::gpu::TensorOrMemDesc>(user->getResult(0).getType()).getEncoding();
if (isa<ttg::DotOperandEncodingAttr>(enc)) {
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
auto order = ttg::getOrder(srcTy.getEncoding());
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
tempAttr = ttg::SwizzledSharedEncodingAttr::get(
val.getContext(), cast<ttg::DotOperandEncodingAttr>(enc),
srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false);
} else if (enc.getAbstractAttribute().getName().str() ==
"triton.gpu.sparse_dot_meta_encoding") {
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
tempAttr = ttg::SwizzledSharedEncodingAttr::get(
val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
ttg::getOrder(srcTy.getEncoding()),
ttg::getCTALayout(srcTy.getEncoding()));
} else {
return std::nullopt;
auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType());
auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
auto order = ttg::getOrder(srcTy.getEncoding());
unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
tempAttr = ttg::SwizzledSharedEncodingAttr::get(
val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
bitWidth, /*needTrans=*/false);
}
}
// Check that the shared encodings needed by the users are compatible.
if (attr != nullptr && attr != tempAttr) {
Expand Down
Loading

0 comments on commit f6f77a1

Please sign in to comment.