Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to upstream StablehloToLinalg code. #19792

Merged
merged 13 commits into from
Jan 27, 2025
3 changes: 3 additions & 0 deletions build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def __init__(self, repo_map: Dict[str, str]):
"@stablehlo//:stablehlo_passes": [
"StablehloPasses",
],
"@stablehlo//:linalg_passes": [
"StablehloLinalgTransforms",
],
"@stablehlo//:vhlo_ops": [
"VhloOps",
],
Expand Down
10 changes: 1 addition & 9 deletions compiler/plugins/input/StableHLO/Conversion/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,8 @@ iree_compiler_cc_library(
"LegalizeToLinalgUtils.h",
"MapStableHLOToScalarOp.h",
"StableHLOCustomCalls.cpp",
"StableHLOToArith.cpp",
"StableHLOToIREEInputDialects.cpp",
"StableHLOToLinalg.cpp",
"StableHLOToLinalgConvolution.cpp",
"StableHLOToLinalgDotProd.cpp",
"StableHLOToLinalgExt.cpp",
"StableHLOToLinalgPointwise.cpp",
"StableHLOToLinalgRandom.cpp",
"StableHLOToLinalgReduce.cpp",
"TypeConversion.cpp",
"TypeConversion.h",
"VerifyCompilerInputLegality.cpp",
],
deps = [
Expand Down Expand Up @@ -121,6 +112,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:VectorDialect",
"@stablehlo//:broadcast_utils",
"@stablehlo//:chlo_ops",
"@stablehlo//:linalg_passes",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:vhlo_ops",
],
Expand Down
10 changes: 1 addition & 9 deletions compiler/plugins/input/StableHLO/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,8 @@ iree_cc_library(
"LegalizeToLinalgUtils.h"
"MapStableHLOToScalarOp.h"
"StableHLOCustomCalls.cpp"
"StableHLOToArith.cpp"
"StableHLOToIREEInputDialects.cpp"
"StableHLOToLinalg.cpp"
"StableHLOToLinalgConvolution.cpp"
"StableHLOToLinalgDotProd.cpp"
"StableHLOToLinalgExt.cpp"
"StableHLOToLinalgPointwise.cpp"
"StableHLOToLinalgRandom.cpp"
"StableHLOToLinalgReduce.cpp"
"TypeConversion.cpp"
"TypeConversion.h"
"VerifyCompilerInputLegality.cpp"
DEPS
::CHLODecompositionPatterns
Expand Down Expand Up @@ -99,6 +90,7 @@ iree_cc_library(
MLIRTransforms
MLIRVectorDialect
StablehloBroadcastUtils
StablehloLinalgTransforms
StablehloOps
VhloOps
iree::compiler::Dialect::Flow::IR
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,10 @@
#include <string>
#include <utility>

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
namespace {
bool hasIntegralShapeType(Operation *op) {
auto stp = llvm::dyn_cast<ShapedType>(op->getOperand(0).getType());
return stp && stp.getElementType().isIntOrIndex();
}

} // namespace

SmallVector<utils::IteratorType, 3>
getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction) {
Expand All @@ -42,124 +31,6 @@ getNParallelLoopsAttrs(unsigned nParallelLoops) {
return getParallelAndReductionIterators(nParallelLoops, 0);
}

Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes) {
return b.create<bufferization::AllocTensorOp>(
loc, llvm::cast<TensorType>(type), dynSizes,
/*copy=*/Value(),
/*memory_space=*/IntegerAttr());
}

Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes) {
return b.create<tensor::EmptyOp>(
loc, type.getShape(), type.getElementType(), dynSizes,
llvm::cast<RankedTensorType>(type).getEncoding());
}

Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
Operation *op, ValueRange operands) {
bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr;
// Collect the sizes for a ranked tensor to be passed as parameter to a
// new tensor initialization operation. This operation only needs the
// dynamic sizes.
SmallVector<Value> sizes;
if (resultType.hasRank() && !resultType.hasStaticShape()) {
// Ask the op for its output shape.
auto shapeSource = cast<InferShapedTypeOpInterface>(op);
SmallVector<Value, 1> reifiedShapes;
(void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
assert(reifiedShapes.size() == 1 && "Expected one reified result");
// Construct sizes for the required dimensions.
for (const auto &en : llvm::enumerate(resultType.getShape())) {
if (!ShapedType::isDynamic(en.value()))
continue;
sizes.push_back(b.create<tensor::ExtractOp>(
loc, reifiedShapes[0],
ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())}));
}
}
return isSparse ? getEmptySparseTensor(b, loc, resultType, sizes)
: getEmptyTensor(b, loc, resultType, sizes);
}

Value coerceTensorShape(OpBuilder &builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType) {
return builder.createOrFold<tensor::CastOp>(
loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()),
value);
}

LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op) {
auto isRankedTensor = [](Value val) {
return isa<RankedTensorType>(val.getType());
};
if (!llvm::all_of(op->getOperands(), isRankedTensor))
return failure();
return success(llvm::all_of(op->getResults(), isRankedTensor));
}

Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor) {
auto type = cast<ShapedType>(tensor.getType());
Value zero;
// Complex numbers are a special case.
if (auto complexType = llvm::dyn_cast<ComplexType>(type.getElementType())) {
auto zeroElement = builder.getZeroAttr(complexType.getElementType());
auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement});
zero = builder.create<complex::ConstantOp>(loc, complexType, zeroAttr);
} else {
auto zeroAttr = builder.getZeroAttr(type.getElementType());
zero = builder.create<arith::ConstantOp>(loc, zeroAttr);
}
return builder.create<linalg::FillOp>(loc, zero, tensor).result();
}

Value preSparsify(Operation *op, llvm::SmallVector<Value, 2> &values, Type rtp,
OpBuilder *b) {
// Apply for semi-ring operations that lower to elaborate code
// (any sign-op, or an integral abs-op).
// TODO(peiming, ajcbik): these all can potentially be optimized by applying
// value transform on sparse_tenosr.value memref
if (isa<mlir::stablehlo::SignOp, mlir::stablehlo::NegOp>(op) ||
(isa<mlir::stablehlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
isa<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp, chlo::AtanhOp,
chlo::BesselI1eOp, chlo::SinhOp, chlo::TanOp>(op)) {
if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
!sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
return Value();
Location loc = op->getLoc();
auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]);
Type itp = values[0].getType();
Block *present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc);
b->setInsertionPointToStart(&semiring.getPresentRegion().front());
values[0] = present->getArgument(0);
return semiring;
}
return Value();
}

Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b) {
if (semiring) {
b->create<sparse_tensor::YieldOp>(op->getLoc(), result);
b->setInsertionPointAfter(semiring.getDefiningOp());
return semiring;
}
return result;
}

bool allOperandsAreScalarTensors(Operation *op) {
return llvm::all_of(op->getOperands(), [](Value operand) {
auto operandTy = llvm::dyn_cast<ShapedType>(operand.getType());
return operandTy && operandTy.getRank() == 0;
});
}

bool isInBodyOfLinalgOps(Operation *op) {
auto *parentOp = op->getParentRegion()->getParentOp();
return parentOp->getDialect() ==
parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
}

SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements) {
SmallVector<int64_t> ret;
for (const APInt &element : elements) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,16 @@
#include <string>
#include <utility>

#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
Expand All @@ -49,75 +39,9 @@ getParallelAndReductionIterators(unsigned nLoops, unsigned nReduction);
SmallVector<utils::IteratorType, 3>
getNParallelLoopsAttrs(unsigned nParallelLoops);

/// Generates an init sparse tensor.
Value getEmptySparseTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes);

/// Generates a tensor.empty op.
Value getEmptyTensor(OpBuilder &b, Location loc, ShapedType type,
ArrayRef<Value> dynSizes);

/// Generates an empty tensor for the result of the operation, which could be a
/// dense tensor or a sparse tensor.
Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
Operation *op, ValueRange operands);

/// Ensures a tensor has the same shape (not including the element type) as
/// another.
Value coerceTensorShape(OpBuilder &builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType);

/// Verifies |op|'s semantics by checking if all operands and results have
/// ranged tensor types.
LogicalResult verifyHloOpBufferOrTensorSemantics(Operation *op);

/// Fills |tensor| with a zero constant of the matching type. Returns the new
/// value.
Value fillTensorWithZeros(OpBuilder &builder, Location loc, Value tensor);

/// Sparsifies a (block of) operation(s) that cannot be handled directly
/// by the sparse compiler but has well-known semi-ring semantics.
///
/// This yields something of the following form:
///
/// %result = sparse_tensor.unary %values[0]
/// present={
/// ^bb1(%val):
/// ... codegen proceeds here using %val ....
/// sparse_tensor.yield
/// }
/// absent={}
/// linalg.yield %result
Value preSparsify(Operation *op, llvm::SmallVector<Value, 2> &values, Type rtp,
OpBuilder *b);

/// Finalizes sparse semi-ring construction.
Value postSparsify(Operation *op, Value semiring, Value result, OpBuilder *b);

/// Returns true if all operands are tensors with rank 0.
bool allOperandsAreScalarTensors(Operation *op);

/// Returns true if parent op is linalg.
bool isInBodyOfLinalgOps(Operation *op);

/// Extracts integer values from the attribute |elements|.
SmallVector<int64_t> extract1DVector(DenseIntElementsAttr elements);

/// Returns true if the given |values| is a splat of the given |queryValue|.
inline bool isSplatValue(const ArrayRef<int64_t> &values, int64_t queryValue) {
for (auto value : values) {
if (value != queryValue) {
return false;
}
}
return true;
}

/// Returns true if the given |attr| is a splat of the given |value|.
inline bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
}

} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_LEGALIZE_TO_LINALG_UTILS_H_
1 change: 0 additions & 1 deletion compiler/plugins/input/StableHLO/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ void buildStableHLOInputConversionPassPipelineImpl(
stablehlo::createConvertStableHloToLinalgExt());
passManager.addNestedPass<func::FuncOp>(stablehlo::createLegalizeChlo());
passManager.addPass(createConvertStableHloToIreeInputDialects());
// Ensure conversion completed.
passManager.addPass(createReconcileUnrealizedCastsPass());

// Note that some StableHLO ops are left by the above and must resolve via
Expand Down
9 changes: 2 additions & 7 deletions compiler/plugins/input/StableHLO/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
#include "compiler/plugins/input/StableHLO/Conversion/PassDetail.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
class TypeConverter;
namespace iree_compiler::stablehlo {

std::unique_ptr<TypeConverter> createStableHloToLinalgTypeConverter();
namespace mlir::iree_compiler::stablehlo {

struct StableHloOptions : public PassPipelineOptions<StableHloOptions> {};

Expand All @@ -36,7 +32,6 @@ void buildStableHLOXLAInputConversionPassPipeline(

void registerStableHLOConversionPasses();

} // namespace iree_compiler::stablehlo
} // namespace mlir
} // namespace mlir::iree_compiler::stablehlo

#endif // IREE_COMPILER_PLUGINS_INPUT_STABLEHLO_CONVERSION_PASSES_H_
9 changes: 0 additions & 9 deletions compiler/plugins/input/StableHLO/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ def ConvertStableHloToLinalgExt :
// General passes
//===----------------------------------------------------------------------===//

def ConvertStableHloToLinalg :
Pass<"iree-stablehlo-to-linalg", "ModuleOp"> {
let summary = "Converts from StableHLO ops to Linalg ops on";
let options = [Option<"enablePrimitiveOps", "enable-primitive-ops", "bool",
/*default=*/"false",
"Lower to primitive Linalg ops (map, reduce and "
"transpose) when possible, instead of linalg.generic">];
}

def LegalizeControlFlow :
InterfacePass<"iree-stablehlo-legalize-control-flow", "mlir::FunctionOpInterface"> {
let summary = "Legalizes from StableHLO control flow to SCF control flow";
Expand Down
Loading
Loading