Integrate LLVM at llvm/llvm-project@43d71baae36c

ghpvnist committed Feb 21, 2025
1 parent c95da49 commit 2bcd399
Showing 10 changed files with 385 additions and 60 deletions.
2 changes: 1 addition & 1 deletion BUILD.bazel
@@ -994,7 +994,7 @@ cc_library(
":linalg_passes",
":reference_api",
":reference_configuration",
-":stablehlo_dialect_capi_objects",
+":stablehlo_dialect_capi",
":stablehlo_ops",
":stablehlo_passes",
":stablehlo_portable_api",
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

-LLVM_COMMIT = "0e779ad4998ef65907502101c5b82ede05ddfa4e"
+LLVM_COMMIT = "43d71baae36c8d8b5a9995aa35efebe09cc9c2d6"

-LLVM_SHA256 = "d5c2560b2d9ce3ced7951113f2b5d1ea428665678f4dcb1fb8780eb1219ca615"
+LLVM_SHA256 = "436af8b4c3403e251ab0b7a471eda7df6063f9da9d22ccbe498f3115cd35225a"

http_archive(
name = "llvm-raw",
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
@@ -1 +1 @@
-0e779ad4998ef65907502101c5b82ede05ddfa4e
+43d71baae36c8d8b5a9995aa35efebe09cc9c2d6
158 changes: 120 additions & 38 deletions stablehlo/dialect/ChloOps.cpp
@@ -16,9 +16,13 @@ limitations under the License.

#include "stablehlo/dialect/ChloOps.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <optional>
#include <string>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -426,12 +430,12 @@ namespace {
// Mode 1, where the ragged dimension is an lhs non-contracting dim (m).
// lhs : [b, m, k]
// rhs : [g, b, k, n]
-// group_sizes : [g]
+// group_sizes : [b, g]
// result : [b, m, n]
// Mode 2, where the ragged dimension is an lhs/rhs contracting dim (k).
// lhs : [b, m, k]
// rhs : [b, k, n]
-// group_sizes : [g]
+// group_sizes : [b, g]
// result : [g, b, m, n]
// Mode 3, where the ragged dimension is an lhs/rhs batch dim (b).
// lhs : [b, m, k]
@@ -440,9 +444,18 @@ namespace {
// result : [b, m, n]
// As with dot_general, the lhs and rhs can have arbitrary batching,
// contracting and non-contracting dimensions.
+// The group_sizes arg has the shape [b...,x...,g], where:
+// - b... are all the lhs batch dims before (outer-to) the lhs ragged dim,
+// - x... are,
+//   - in mode 1, all the lhs non-contracting dims before the lhs ragged dim,
+//   - in mode 2, all the lhs contracting dims before the lhs ragged dim, and
+//   - in mode 3, empty;
+// - g is the number of groups in the lhs ragged dim.
// Additionally:
// - In all modes, the lhs must have exactly one ragged dimension.
// - In mode 1, the rhs must have exactly one group dimension.
+// - If a group_sizes of shape [g] is passed, it is broadcasted according to
+//   the rules above.
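// Illustrative note (editorial, not in the source): in mode 1 with an lhs
// of shape [19, 17, 11, 5], lhs_batching_dimensions = [0],
// lhs_contracting_dimensions = [3], and lhs ragged dim 2, the prefix is
// b... = [19] and x... = [17], so group_sizes must have shape [19, 17, g]
// (or just [g], which is then broadcasted).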
LogicalResult checkRaggedDotConstraints(
std::optional<Location> location, RankedTensorType rankedLhsType,
RankedTensorType rankedRhsType, RankedTensorType rankedGroupSizesType,
@@ -452,14 +465,6 @@ LogicalResult checkRaggedDotConstraints(
ArrayRef<int64_t> rhsContractingDimensions,
ArrayRef<int64_t> lhsRaggedDimensions,
ArrayRef<int64_t> rhsGroupDimensions) {
-// Check that the group sizes has rank=1.
-if (rankedGroupSizesType.getRank() != 1) {
-return emitOptionalError(
-location, "expected rank of group_sizes of ragged dot to be 1, got ",
-rankedGroupSizesType.getRank());
-}
-auto numGroups = rankedGroupSizesType.getDimSize(0);

// Check that there is exactly one lhs ragged dimension.
if (lhsRaggedDimensions.size() != 1) {
return emitOptionalError(
@@ -474,6 +479,81 @@ LogicalResult checkRaggedDotConstraints(
return failure();
}

+enum Mode {
+// Ragged non-contracting (m): [b,m,k], [g,b,k,n], [b,g] -> [b,m,n].
+kNonContracting,
+// Ragged contracting (k): [b,m,k], [b,k,n], [b,g] -> [g,b,m,n].
+kContracting,
+// Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
+kBatch
+};
+Mode mode;
+if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim)) {
+mode = kBatch;
+} else if (llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
+mode = kContracting;
+} else {
+mode = kNonContracting;
+}
+
+// Validate the shape of group_sizes.
+{
+// Construct the expected shape [b...,x...,g] of group_sizes.
+SmallVector<int64_t> prefixDims;
+prefixDims.reserve(rankedLhsType.getRank() - 1);
+prefixDims.insert(prefixDims.end(), lhsBatchingDimensions.begin(),
+lhsBatchingDimensions.end());
+switch (mode) {
+case kBatch:
+prefixDims.resize(
+std::distance(lhsBatchingDimensions.begin(),
+llvm::find(lhsBatchingDimensions, lhsRaggedDim)));
+break;
+case kContracting:
+prefixDims.insert(prefixDims.end(), lhsContractingDimensions.begin(),
+llvm::find(lhsContractingDimensions, lhsRaggedDim));
+break;
+case kNonContracting:
+for (int64_t i = 0; i < lhsRaggedDim; ++i) {
+if (!llvm::is_contained(lhsBatchingDimensions, i) &&
+!llvm::is_contained(lhsContractingDimensions, i)) {
+prefixDims.push_back(i);
+}
+}
+break;
+}
+SmallVector<int64_t> expectedPrefix;
+expectedPrefix.reserve(prefixDims.size());
+for (const int64_t dim : prefixDims) {
+expectedPrefix.push_back(rankedLhsType.getDimSize(dim));
+}
+
+// Validate the actual shape, if it was passed as something other than [g].
+if (rankedGroupSizesType.getRank() != 1) {
+if (rankedGroupSizesType.getRank() != expectedPrefix.size() + 1) {
+return emitOptionalError(location, "expected group_sizes to have rank ",
+expectedPrefix.size() + 1, ", got ",
+rankedGroupSizesType.getRank());
+}
+auto groupSizesShape = rankedGroupSizesType.getShape();
+if (!std::equal(expectedPrefix.begin(), expectedPrefix.end(),
+groupSizesShape.begin())) {
+auto nonEmptyShapeStr = [](ArrayRef<int64_t> shape) {
+std::string s = "";
+for (int64_t i = 0; i < shape.size() - 1; ++i) {
+s += std::to_string(shape[i]) + ", ";
+}
+return s + std::to_string(shape.back());
+};
+return emitOptionalError(
+location, "group_sizes is expected to have shape [",
+nonEmptyShapeStr(expectedPrefix), ", ", groupSizesShape.back(),
+"], got [", nonEmptyShapeStr(groupSizesShape), "]");
+}
+}
+}
+const int64_t numGroups = rankedGroupSizesType.getShape().back();

// Validate basic properties of the rhs group dimension(s).
for (auto rhsGroupDim : rhsGroupDimensions) {
if (failed(hlo::checkDimInBounds(location, rhsGroupDim,
@@ -491,32 +571,34 @@ LogicalResult checkRaggedDotConstraints(
return failure();
}

-if (llvm::is_contained(lhsBatchingDimensions, lhsRaggedDim) ||
-llvm::is_contained(lhsContractingDimensions, lhsRaggedDim)) {
-// Ragged batch (b): [b,m,k], [b,k,n], [g] -> [b,m,n].
-// Ragged contracting (k): [b,m,k], [b,k,n], [g] -> [g,b,m,n].
-if (!rhsGroupDimensions.empty()) {
-return emitOptionalError(
-location,
-"There must be zero group dimensions in the rhs when the "
-"ragged dimension is batch or contracting.");
-}
-} else {
-// Ragged non-contracting (m): [b,m,k], [g,b,k,n], [g] -> [b,m,n].
-if (rhsGroupDimensions.size() != 1) {
-return emitOptionalError(
-location,
-"There must be exactly one group dimension in the rhs when the lhs "
-"ragged dimension is non-contracting.");
-}
-// Compare the group dimension size with the number of groups.
-const int64_t rhsGroupDim = rhsGroupDimensions[0];
-if (!hlo::verifyCompatibleDims(numGroups,
-rankedRhsType.getDimSize(rhsGroupDim))) {
-return emitOptionalError(
-location, "group_sizes is expected to have shape=[",
-rankedRhsType.getDimSize(rhsGroupDim), "], got [", numGroups, "]");
-}
+switch (mode) {
+case kBatch:
+[[fallthrough]];
+case kContracting:
+if (!rhsGroupDimensions.empty()) {
+return emitOptionalError(
+location,
+"There must be zero group dimensions in the rhs when the "
+"ragged dimension is batch or contracting.");
+}
+break;
+case kNonContracting:
+if (rhsGroupDimensions.size() != 1) {
+return emitOptionalError(
+location,
+"There must be exactly one group dimension in the rhs when the lhs "
+"ragged dimension is non-contracting.");
+}
+// Compare the group dimension size with the number of groups.
+const int64_t rhsGroupDim = rhsGroupDimensions[0];
+if (!hlo::verifyCompatibleDims(numGroups,
+rankedRhsType.getDimSize(rhsGroupDim))) {
+return emitOptionalError(
+location,
+"rhs group dimension is expected to have size=", numGroups,
+", got ", rankedRhsType.getDimSize(rhsGroupDim));
+}
+break;
}
return success();
}
@@ -530,10 +612,10 @@ SmallVector<int64_t> inferRaggedDotOutputDimensions(
ArrayRef<int64_t> rhsContractingDimensions,
ArrayRef<int64_t> lhsRaggedDimensions,
ArrayRef<int64_t> rhsGroupDimensions) {
-// Must have already checked that group_sizes is 1-D.
-const int64_t numGroups = rankedGroupSizesType.getDimSize(0);
// Must have already checked that there is exactly one lhs ragged dim.
const int64_t lhsRaggedDim = lhsRaggedDimensions[0];
+// Must have already checked the shape of group_sizes.
+const int64_t numGroups = rankedGroupSizesType.getShape().back();

SmallVector<int64_t> dimensions;
// Add the group dimension to the result shape in case of ragged contracting.
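For concreteness, here is a minimal mode-2 example that the updated verifier should accept; this is an editorial sketch, not part of the commit, and the test-style function name is invented. The lhs batch dim contributes 19 and the lhs contracting dim outer to the ragged dim contributes 17, so the full group_sizes shape is [19, 17, 3]; a plain tensor<3xi64> would also pass under the broadcast rule:

func.func @ragged_dot_mode2_full_group_sizes(%lhs : tensor<19x11x17x5xf32>, %rhs : tensor<19x17x5x7xf32>, %group_sizes : tensor<19x17x3xi64>) -> tensor<3x19x11x7xf32> {
  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [2,3],
      rhs_contracting_dimensions = [1,2],
      lhs_ragged_dimensions = [3],
      rhs_group_dimensions = []
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<19x11x17x5xf32>, tensor<19x17x5x7xf32>, tensor<19x17x3xi64>) -> tensor<3x19x11x7xf32>
  func.return %0 : tensor<3x19x11x7xf32>
}

Note how the result type prepends the group count g, giving tensor<3x19x11x7xf32>, matching the [g,b,m,n] signature that inferRaggedDotOutputDimensions produces for ragged-contracting mode.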
4 changes: 2 additions & 2 deletions stablehlo/dialect/ChloOps.td
@@ -869,12 +869,12 @@ def CHLO_RaggedDotOp : CHLO_Op<"ragged_dot",
most one group dimension. The op has three modes, depending on the kind of
the lhs ragged dimension.

-In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [g] -> [b,m,n]`.
+In mode 1, the shape-signature is `[b,m,k], [g,b,k,n], [b,g] -> [b,m,n]`.
Here the ragged dimension is an lhs non-contracting dimension (`m`). The
dimensions `b` and `k` represent batch and contracting dimensions
respectively. The rhs is required to have a group dimension (`g`).

-In mode 2, the shape-signature is `[b,m,k], [b,k,n], [g] -> [g,b,m,n]`.
+In mode 2, the shape-signature is `[b,m,k], [b,k,n], [b,g] -> [g,b,m,n]`.
Here the ragged dimension is an lhs/rhs contracting dimension (`k`).

In mode 3, the shape-signature is `[b,m,k], [b,k,n], [g] -> [b,m,n]`. Here
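To illustrate mode 3 (again an editorial sketch with an invented test name, not part of the commit): the ragged dimension is the second batch dim, and only the batch dims outer to it enter the group_sizes prefix, so with batch dims of sizes 17 and 19 the expected full shape is [17, 3]:

func.func @ragged_dot_mode3_full_group_sizes(%lhs : tensor<17x19x11x5xf32>, %rhs : tensor<17x19x5x7xf32>, %group_sizes : tensor<17x3xi64>) -> tensor<17x19x11x7xf32> {
  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [0,1],
      rhs_batching_dimensions = [0,1],
      lhs_contracting_dimensions = [3],
      rhs_contracting_dimensions = [2],
      lhs_ragged_dimensions = [1],
      rhs_group_dimensions = []
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<17x19x11x5xf32>, tensor<17x19x5x7xf32>, tensor<17x3xi64>) -> tensor<17x19x11x7xf32>
  func.return %0 : tensor<17x19x11x7xf32>
}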
77 changes: 74 additions & 3 deletions stablehlo/tests/ops_chlo.mlir
@@ -146,7 +146,7 @@ func.func @ragged_dot_incompatible_contracting_dims(%lhs : tensor<11x5xf32>, %rh
// -----

func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3x2xi64>) -> tensor<11x7xf32> {
-// @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
+// @expected-error@+1 {{expected group_sizes to have rank 1, got 2}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
@@ -163,8 +163,79 @@ func.func @ragged_dot_group_sizes_incorrect_rank(%lhs : tensor<11x5xf32>, %rhs :

// -----

-func.func @ragged_dot_group_sizes_incorrect_shape(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
-// @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}}
+func.func @ragged_dot_mode1_group_sizes_broadcasted(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<19x17x11x7xf32> {
+%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ragged_dot_dimension_numbers = #chlo.ragged_dot<
+lhs_batching_dimensions = [0],
+rhs_batching_dimensions = [1],
+lhs_contracting_dimensions = [3],
+rhs_contracting_dimensions = [2],
+lhs_ragged_dimensions = [2],
+rhs_group_dimensions = [0]
+>,
+precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+} : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<3xi64>) -> tensor<19x17x11x7xf32>
+func.return %0 : tensor<19x17x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_mode1_group_sizes_incorrect_shape(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<3x19x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32> {
+// @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}}
+%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ragged_dot_dimension_numbers = #chlo.ragged_dot<
+lhs_batching_dimensions = [0],
+rhs_batching_dimensions = [1],
+lhs_contracting_dimensions = [3],
+rhs_contracting_dimensions = [2],
+lhs_ragged_dimensions = [2],
+rhs_group_dimensions = [0]
+>,
+precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+} : (tensor<19x17x11x5xf32>, tensor<3x19x5x7xf32>, tensor<19x11x3xi64>) -> tensor<19x17x11x7xf32>
+func.return %0 : tensor<19x17x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_mode2_group_sizes_incorrect_shape(%lhs : tensor<19x11x17x5xf32>, %rhs : tensor<19x17x5x7xf32>, %group_sizes : tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32> {
+// @expected-error@+1 {{group_sizes is expected to have shape [19, 17, 3], got [19, 11, 3]}}
+%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ragged_dot_dimension_numbers = #chlo.ragged_dot<
+lhs_batching_dimensions = [0],
+rhs_batching_dimensions = [0],
+lhs_contracting_dimensions = [2,3],
+rhs_contracting_dimensions = [1,2],
+lhs_ragged_dimensions = [3],
+rhs_group_dimensions = []
+>,
+precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+} : (tensor<19x11x17x5xf32>, tensor<19x17x5x7xf32>, tensor<19x11x3xi64>) -> tensor<3x19x11x7xf32>
+func.return %0 : tensor<3x19x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_mode3_group_sizes_incorrect_shape(%lhs : tensor<17x19x11x5xf32>, %rhs : tensor<17x19x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<17x19x11x7xf32> {
+// @expected-error@+1 {{group_sizes is expected to have shape [17, 3], got [19, 3]}}
+%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
+ragged_dot_dimension_numbers = #chlo.ragged_dot<
+lhs_batching_dimensions = [0,1],
+rhs_batching_dimensions = [0,1],
+lhs_contracting_dimensions = [3],
+rhs_contracting_dimensions = [2],
+lhs_ragged_dimensions = [1],
+rhs_group_dimensions = []
+>,
+precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
+} : (tensor<17x19x11x5xf32>, tensor<17x19x5x7xf32>, tensor<19x3xi64>) -> tensor<17x19x11x7xf32>
+func.return %0 : tensor<17x19x11x7xf32>
+}
+
+// -----
+
+func.func @ragged_dot_incorrect_group_dim_size(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<2xi64>) -> tensor<11x7xf32> {
+// @expected-error@+1 {{rhs group dimension is expected to have size=2, got 3}}
%0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
ragged_dot_dimension_numbers = #chlo.ragged_dot<
lhs_batching_dimensions = [],
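For contrast with the failing case above, a valid counterpart (an editorial sketch with an invented test name, not part of the commit): with a rhs group dimension of size 3, a group_sizes of shape [3] satisfies the new check:

func.func @ragged_dot_correct_group_dim_size(%lhs : tensor<11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<11x7xf32> {
  %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) {
    ragged_dot_dimension_numbers = #chlo.ragged_dot<
      lhs_batching_dimensions = [],
      rhs_batching_dimensions = [],
      lhs_contracting_dimensions = [1],
      rhs_contracting_dimensions = [1],
      lhs_ragged_dimensions = [0],
      rhs_group_dimensions = [0]
    >,
    precision_config = [#chlo<precision DEFAULT>, #chlo<precision DEFAULT>]
  } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32>
  func.return %0 : tensor<11x7xf32>
}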