Switch to upstream StablehloToLinalg code. (iree-org#19792)
While looking at compiler warnings in build logs, I noticed paths in
StableHLO that looked out of place. As it turns out, much of IREE's
StableHLO to Linalg conversion code was forked into upstream StableHLO
in openxla/stablehlo#1817, though there have
been some local changes to the code here since it was forked:
https://github.com/iree-org/iree/commits/main/compiler/plugins/input/StableHLO/Conversion.

Switching to the upstream code decreases the surface area we directly
support and limits the number of files we need to build from source, but
it also means maintenance requires more coordination with upstream (such
as during LLVM integrates). We still point to a fork at
https://github.com/iree-org/stablehlo, so if things get tricky we can
set up a branch with patches as needed.

Some notes:

* More code, particularly includes and build dependencies, could be
pruned.
* We can probably delete more code by reviving
iree-org#18681 too.
* I deleted lit tests for the patterns that were moved upstream. The
tests still exist at
https://github.com/openxla/stablehlo/tree/main/stablehlo/conversions/linalg/tests,
but I don't see much value in having our own versions of the lit tests.
We do still have e2e tests that compile and run.
* I did _not_ plumb through the `enablePrimitiveOps` or
`enableSparseOps` options, which may be useful for some programs.
* I'm keeping our custom `stablehlo.concatenate` lowering, since the
alternate lowering (whether from IREE or from upstream StableHLO) has
correctness issues. I'm also keeping the FFT lowering, since it does not
exist upstream and handles cases that our LinalgExt lowering does not.

Signed-off-by: Hyunsung Lee <ita9naiwa@gmail.com>
ScottTodd authored and ita9naiwa committed Feb 4, 2025
1 parent 18eadf7 commit b6585f9
Showing 28 changed files with 31 additions and 12,133 deletions.
3 changes: 3 additions & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -90,6 +90,9 @@ def __init__(self, repo_map: Dict[str, str]):
"@stablehlo//:stablehlo_passes": [
"StablehloPasses",
],
"@stablehlo//:linalg_passes": [
"StablehloLinalgTransforms",
],
"@stablehlo//:vhlo_ops": [
"VhloOps",
],
10 changes: 1 addition & 9 deletions compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
@@ -70,17 +70,8 @@ iree_compiler_cc_library(
"LegalizeToLinalgUtils.h",
"MapStableHLOToScalarOp.h",
"StableHLOCustomCalls.cpp",
"StableHLOToArith.cpp",
"StableHLOToIREEInputDialects.cpp",
"StableHLOToLinalg.cpp",
"StableHLOToLinalgConvolution.cpp",
"StableHLOToLinalgDotProd.cpp",
"StableHLOToLinalgExt.cpp",
"StableHLOToLinalgPointwise.cpp",
"StableHLOToLinalgRandom.cpp",
"StableHLOToLinalgReduce.cpp",
"TypeConversion.cpp",
"TypeConversion.h",
"VerifyCompilerInputLegality.cpp",
],
deps = [
@@ -121,6 +112,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:VectorDialect",
"@stablehlo//:broadcast_utils",
"@stablehlo//:chlo_ops",
"@stablehlo//:linalg_passes",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:vhlo_ops",
],
10 changes: 1 addition & 9 deletions compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
@@ -56,17 +56,8 @@ iree_cc_library(
"LegalizeToLinalgUtils.h"
"MapStableHLOToScalarOp.h"
"StableHLOCustomCalls.cpp"
"StableHLOToArith.cpp"
"StableHLOToIREEInputDialects.cpp"
"StableHLOToLinalg.cpp"
"StableHLOToLinalgConvolution.cpp"
"StableHLOToLinalgDotProd.cpp"
"StableHLOToLinalgExt.cpp"
"StableHLOToLinalgPointwise.cpp"
"StableHLOToLinalgRandom.cpp"
"StableHLOToLinalgReduce.cpp"
"TypeConversion.cpp"
"TypeConversion.h"
"VerifyCompilerInputLegality.cpp"
DEPS
::CHLODecompositionPatterns
@@ -99,6 +90,7 @@ iree_cc_library(
MLIRTransforms
MLIRVectorDialect
StablehloBroadcastUtils
StablehloLinalgTransforms
StablehloOps
VhloOps
iree::compiler::Dialect::Flow::IR
129 changes: 0 additions & 129 deletions compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.cpp
@@ -13,21 +13,10 @@
#include <string>
#include <utility>

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
namespace {
bool hasIntegralShapeType(Operation *op) {
auto stp = llvm::dyn_cast<ShapedType>(op->getOperand(0).getType());
return stp && stp.getElementType().isIntOrIndex();
}

} // namespace

SmallVector<utils::IteratorType, 3>
getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) {
@@ -42,124 +31,6 @@ getNParallelLoopsAttrs(unsigned nParallelLoops) {
return getParallelAndReductionIterators(nParallelLoops, 0);
}

Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes) {
return b.create<bufferization::AllocTensorOp>(
loc, llvm::cast<TensorType>(type), dynSizes,
/*copy=*/Value(),
/*memory_space=*/IntegerAttr());
}

Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes) {
return b.create<tensor::EmptyOp>(
loc, type.getShape(), type.getElementType(), dynSizes,
llvm::cast<RankedTensorType>(type).getEncoding());
}

Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
Operation *op, ValueRange operands) {
bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr;
// Collect the sizes for a ranked tensor to be passed as parameter to a
// new tensor initialization operation. This operation only needs the
// dynamic sizes.
SmallVector<Value> sizes;
if (resultType.hasRank() && !resultType.hasStaticShape()) {
// Ask the op for its output shape.
auto shapeSource = cast<InferShapedTypeOpInterface>(op);
SmallVector<Value, 1> reifiedShapes;
(void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
assert(reifiedShapes.size() == 1 && "Expected one reified result");
// Construct sizes for the required dimensions.
for (const auto &en : llvm::enumerate(resultType.getShape())) {
if (!ShapedType::isDynamic(en.value()))
continue;
sizes.push_back(b.create<tensor::ExtractOp>(
loc, reifiedShapes[0],
ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())}));
}
}
return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes)
: getEmptyTensor(b, loc, resultType, sizes);
}

Value coerceTensorShape(OpBuilder &builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType) {
return builder.createOrFold<tensor::CastOp>(
loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()),
value);
}

LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op) {
auto isRankedTensor = [](Value val) {
return isa<RankedTensorType>(val.getType());
};
if (!llvm::all_of(op->getOperands(), isRankedTensor))
return failure();
return success(llvm::all_of(op->getResults(), isRankedTensor));
}

Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor) {
auto type = cast<ShapedType>(tensor.getType());
Value zero;
// Complex numbers are a special case.
if (auto complexType = llvm::dyn_cast<ComplexType>(type.getElementType())) {
auto zeroElement = builder.getZeroAttr(complexType.getElementType());
auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement});
zero = builder.create<complex::ConstantOp>(loc, complexType, zeroAttr);
} else {
auto zeroAttr = builder.getZeroAttr(type.getElementType());
zero = builder.create<arith::ConstantOp>(loc, zeroAttr);
}
return builder.create<linalg::FillOp>(loc, zero, tensor).result();
}

Value preSparsify(Operation *op, llvm::SmallVector<Value, 2> &values, Type rtp,
OpBuilder *b) {
// Apply for semi-ring operations that lower to elaborate code
// (any sign-op, or an integral abs-op).
// TODO(peiming, ajcbik): these all can potentially be optimized by applying
// value transform on sparse_tensor.value memref
if (isa<mlir::stablehlo::SignOp, mlir::stablehlo::NegOp>(op) ||
(isa<mlir::stablehlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
isa<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp, chlo::AtanhOp,
chlo::BesselI1eOp, chlo::SinhOp, chlo::TanOp>(op)) {
if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
!sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
return Value();
Location loc = op->getLoc();
auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]);
Type itp = values[0].getType();
Block *present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc);
b->setInsertionPointToStart(&semiring.getPresentRegion().front());
values[0] = present->getArgument(0);
return semiring;
}
return Value();
}

Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b) {
if (semiring) {
b->create<sparse_tensor::YieldOp>(op->getLoc(), result);
b->setInsertionPointAfter(semiring.getDefiningOp());
return semiring;
}
return result;
}

bool allOperandsAreScalarTensors(Operation *op) {
return llvm::all_of(op->getOperands(), [](Value operand) {
auto operandTy = llvm::dyn_cast<ShapedType>(operand.getType());
return operandTy && operandTy.getRank() == 0;
});
}

bool isInBodyOfLinalgOps(Operation *op) {
auto *parentOp = op->getParentRegion()->getParentOp();
return parentOp->getDialect() ==
parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
}

SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements) {
SmallVector<int64_t> ret;
for (const APInt &element : elements) {
compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h
@@ -15,26 +15,16 @@
#include <string>
#include <utility>

#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
@@ -49,75 +39,9 @@ getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction);
SmallVector<utils::IteratorType, 3>
getNParallelLoopsAttrs(unsigned nParallelLoops);

/// Generates an init sparse tensor.
Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes);

/// Generates a tensor.empty op.
Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes);

/// Generates an empty tensor for the result of the operation, which could be a
/// dense tensor or a sparse tensor.
Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
Operation *op, ValueRange operands);

/// Ensures a tensor has the same shape (not including the element type) as
/// another.
Value coerceTensorShape(OpBuilder &builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType);

/// Verifies |op|'s semantics by checking if all operands and results have
/// ranged tensor types.
LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op);

/// Fills |tensor| with a zero constant of the matching type. Returns the new
/// value.
Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor);

/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
/// This yields something of the following form:
///
/// %result = sparse_tensor.unary %values[0]
/// present={
/// ^bb1(%val):
/// ... codegen proceeds here using %val ....
/// sparse_tensor.yield
/// }
/// absent={}
/// linalg.yield %result
Value preSparsify(Operation *op, llvm::SmallVector<Value, 2> &values, Type rtp,
OpBuilder *b);

/// Finalizes sparse semi-ring construction.
Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b);

/// Returns true if all operands are tensors with rank 0.
bool allOperandsAreScalarTensors(Operation *op);

/// Returns true if parent op is linalg.
bool isInBodyOfLinalgOps(Operation *op);

/// Extracts integer values from the attribute |elements|.
SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements);

/// Returns true if the given |values| is a splat of the given |queryValue|.
inline bool isSplatValue(const ArrayRef<int64_t> &values, int64_t queryValue) {
for (auto value : values) {
if (value != queryValue) {
return false;
}
}
return true;
}

/// Returns true if the given |attr| is a splat of the given |value|.
inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
}

} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
1 change: 0 additions & 1 deletion compiler/plugins/input/StableHLO/Conversion/Passes.cpp
@@ -87,7 +87,6 @@ void buildStableHLOInputConversionPassPipelineImpl(
stablehlo::createConvertStableHloToLinalgExt());
passManager.addNestedPass<func::FuncOp>(stablehlo::createLegalizeChlo());
passManager.addPass(createConvertStableHloToIreeInputDialects());
// Ensure conversion completed.
passManager.addPass(createReconcileUnrealizedCastsPass());

// Note that some StableHLO ops are left by the above and must resolve via
9 changes: 2 additions & 7 deletions compiler/plugins/input/StableHLO/Conversion/Passes.h
@@ -10,11 +10,7 @@
#include "compiler/plugins/input/StableHLO/Conversion/PassDetail.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
class TypeConverter;
namespace iree_compiler::stablehlo {

std::unique_ptr<TypeConverter> createStableHloToLinalgTypeConverter();
namespace mlir::iree_compiler::stablehlo {

struct StableHloOptions : public PassPipelineOptions<StableHloOptions> {};

@@ -36,7 +32,6 @@ void buildStableHLOXLAInputConversionPassPipeline(

void registerStableHLOConversionPasses();

} // namespace iree_compiler::stablehlo
} // namespace mlir
} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_PASSES_H_
9 changes: 0 additions & 9 deletions compiler/plugins/input/StableHLO/Conversion/Passes.td
@@ -32,15 +32,6 @@ def ConvertStableHloToLinalgExt :
// General passes
//===----------------------------------------------------------------------===//

def ConvertStableHloToLinalg :
Pass<"iree-stablehlo-to-linalg", "ModuleOp"> {
let summary = "Converts from StableHLO ops to Linalg ops on";
let options = [Option<"enablePrimitiveOps", "enable-primitive-ops", "bool",
/*default=*/"false",
"Lower to primitive Linalg ops (map, reduce and "
"transpose) when possible, instead of linalg.generic">];
}

def LegalizeControlFlow :
InterfacePass<"iree-stablehlo-legalize-control-flow", "mlir::FunctionOpInterface"> {
let summary = "Legalizes from StableHLO control flow to SCF control flow";