Skip to content

Commit

Permalink
Add source location downgrade to compatibility expander
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK committed Feb 27, 2025
1 parent 03597b1 commit a7f73a1
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: stablehlo-opt %s -stablehlo-compatibility-expander=target=1.8.0 --mlir-print-debuginfo | FileCheck %s

// Test that FileLineColRange locations are converted to FileLineColLoc
// locations, including in nested location contexts, block args, module op, etc.
// Ex: loc("file.mlir":2:21 to :30) ==> loc("file.mlir":2:21)

#loc3 = loc("file.mlir":2:21 to :30)
module {
func.func @main(%arg0: tensor<i32> loc("file.mlir":2:21 to :30)) -> tensor<i32> {
%c = stablehlo.constant dense<1> : tensor<i32> loc(#loc4)
%0 = stablehlo.add %arg0, %c : tensor<i32> loc(#loc5)
return %0 : tensor<i32> loc(#loc6)
} loc(#loc9)
} loc(#loc)
#loc = loc("file.mlir":0:0 to :3)
#loc1 = loc("file.mlir":1:1 to :2)
#loc2 = loc("file.mlir":2:19 to :20)
#loc4 = loc("file.mlir":2:8 to :10)
#loc5 = loc("file.mlir":4:10 to :12)
#loc6 = loc("file.mlir":3:3 to :5)
#loc7 = loc("WrappedLocation.call"(#loc1))
#loc8 = loc("WrappedLocation.callsite"(#loc2))
#loc9 = loc(callsite(#loc7 at #loc8))

// CHECK: #[[LOC3:.*]] = loc("file.mlir":2:21)
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @main{{.*}}tensor<i32> loc("file.mlir":2:21)
// CHECK-NEXT: stablehlo.constant {{.*}} loc(#[[LOC4:.*]])
// CHECK-NEXT: stablehlo.add {{.*}} : tensor<i32> loc(#[[LOC5:.*]])
// CHECK-NEXT: return {{.*}} loc(#[[LOC6:.*]])
// CHECK-NEXT: } loc(#[[LOC9:.*]])
// CHECK-NEXT: } loc(#[[LOC:.*]])
// CHECK-NEXT: #[[LOC]] = loc("file.mlir":0:0)
// CHECK-NEXT: #[[LOC1:.*]] = loc("file.mlir":1:1)
// CHECK-NEXT: #[[LOC2:.*]] = loc("file.mlir":2:19)
// CHECK-NEXT: #[[LOC4]] = loc("file.mlir":2:8)
// CHECK-NEXT: #[[LOC5]] = loc("file.mlir":4:10)
// CHECK-NEXT: #[[LOC6]] = loc("file.mlir":3:3)
// CHECK-NEXT: #[[LOC7:.*]] = loc("WrappedLocation.call"(#[[LOC1]]))
// CHECK-NEXT: #[[LOC8:.*]] = loc("WrappedLocation.callsite"(#[[LOC2]]))
// CHECK-NEXT: #[[LOC9]] = loc(callsite(#[[LOC7]] at #[[LOC8]]))

2 changes: 1 addition & 1 deletion stablehlo/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism",
}];
}

def StablehloCompatibilityExpanderPass : Pass<"stablehlo-compatibility-expander", "mlir::func::FuncOp"> {
def StablehloCompatibilityExpanderPass : Pass<"stablehlo-compatibility-expander", "mlir::ModuleOp"> {
let summary = "Compatibility expander for StableHLO operations.";

let description = [{
Expand Down
70 changes: 63 additions & 7 deletions stablehlo/transforms/StablehloCompatibilityExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,44 @@ limitations under the License.
==============================================================================*/

#include <fcntl.h>
#include <stdbool.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <optional>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/AttrTypeSubElements.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/PassUtils.h"
#include "stablehlo/transforms/PassUtils.h" // IWYU pragma: keep
#include "stablehlo/transforms/Passes.h"

#define DEBUG_TYPE "compat-passes"

namespace mlir {
namespace stablehlo {
#define GEN_PASS_DEF_STABLEHLOCOMPATIBILITYEXPANDERPASS
Expand Down Expand Up @@ -167,7 +177,7 @@ Value createConcatIndices(Value indices, int64_t indexVectorDim,
// Converts a `GatherOp` with batching dims to a `GatherOp` without batching
// dims, such that each batching dim becomes a collapsed slice dim with a
// corresponding `IotaOp` concatenated to the start indices.
class GatherWithBatchingDimsExpander : public OpRewritePattern<GatherOp> {
struct GatherWithBatchingDimsExpander : public OpRewritePattern<GatherOp> {
using OpRewritePattern<GatherOp>::OpRewritePattern;

LogicalResult matchAndRewrite(GatherOp op,
Expand Down Expand Up @@ -213,7 +223,7 @@ class GatherWithBatchingDimsExpander : public OpRewritePattern<GatherOp> {
// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching
// dims, such that each batching dim becomes an inserted window dim with a
// corresponding `IotaOp` concatenated to the scatter indices.
class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
struct ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
using OpRewritePattern<ScatterOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ScatterOp op,
Expand Down Expand Up @@ -262,6 +272,43 @@ class ScatterWithBatchingDimsExpander : public OpRewritePattern<ScatterOp> {
}
};

// FileLineColRange locations are a forward incompatibility in upstream MLIR.
// This pattern removes the precise start/end range information and converts
// all FileLineColRange locations to forward compatible FileLineColLoc
// locations.
struct FileLineColRangeToLoc : public OpRewritePattern<ModuleOp> {
using OpRewritePattern<ModuleOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ModuleOp op,
PatternRewriter &rewriter) const override {
bool changed = false;
mlir::AttrTypeReplacer replacer;
replacer.addReplacement([&](FileLineColLoc flcLoc)
-> std::optional<Location> {
// Skip if it's actually a FileLineColLoc
if (isStrictFileLineColLoc(flcLoc)) return flcLoc;

// Replace FileLineColRange with FileLineColLoc
changed = true;
auto newFlcLoc = FileLineColLoc::get(
flcLoc.getFilename(), flcLoc.getStartLine(), flcLoc.getStartColumn());
LLVM_DEBUG(llvm::dbgs()
<< "Rewriting FLC " << flcLoc << " -> " << newFlcLoc << "\n");
return newFlcLoc;
});

// Call this on the module to update all locations in the module.
// This should be safe since this pass is declared as a ModuleOp level pass
// in the pass TD file, so no async issues.
replacer.recursivelyReplaceElementsIn(op,
/*replaceAttrs=*/false,
/*replaceLocs=*/true,
/*replaceTypes=*/false);

return success(changed);
}
};

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
Expand All @@ -282,6 +329,7 @@ struct StablehloCompatibilityExpanderPass
auto targetVersion = validateTargetVersion(targetVersionOption);

config.useTopDownTraversal = true;

RewritePatternSet patterns_(context);
populateStablehloCompatibilityExpanderPatterns(&patterns_, context,
targetVersion);
Expand All @@ -290,9 +338,13 @@ struct StablehloCompatibilityExpanderPass
}

void runOnOperation() override {
auto func = getOperation();
if (failed(applyPatternsGreedily(func, patterns, config))) {
func.emitError(
auto module = getOperation();

// Apply to both the module and its children
if (failed(
applyOpPatternsGreedily(module.getOperation(), patterns, config)) ||
failed(applyPatternsGreedily(module, patterns, config))) {
module.emitError(
"Failed to converge StableHLOCompatibilityExpanderPass in ")
<< config.maxIterations << " iterations";
signalPassFailure();
Expand Down Expand Up @@ -321,6 +373,10 @@ void populateStablehloCompatibilityExpanderPatterns(
if (targetVersion < vhlo::Version(1, 4, 0))
patterns->add<TanOp_ComplexElementType_CompatiblityExpander,
TanOp_CompatiblityExpander>(context);

// MLIR Upstream FileLineColRange introduced just before v1.8.4.
if (targetVersion < vhlo::Version(1, 8, 4))
patterns->add<FileLineColRangeToLoc>(context);
}

} // namespace stablehlo
Expand Down
30 changes: 30 additions & 0 deletions stablehlo/transforms/VhloToVersion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Types.h"
Expand Down Expand Up @@ -186,6 +187,28 @@ LogicalResult isLegalType(Type type, const Version& targetVersion) {
return success();
}

bool isLegalLocation(Location loc, const Version& targetVersion) {
// FileLineColRange locations are a forward incompatibility in upstream MLIR
// just before v1.8.4 was tagged. Support for downgrading these locations
// exists in StablehloCompatibilityExpanderPass.
bool isLegal = true;
loc->walk([&](Location childLoc) -> WalkResult {
if (auto fileLineColLoc = dyn_cast<FileLineColRange>(childLoc)) {
static const Version kFileLineColLocMinVersion = Version(1, 8, 4);
if (!isStrictFileLineColLoc(loc) &&
targetVersion < kFileLineColLocMinVersion) {
LLVM_DEBUG(llvm::dbgs() << "failed to legalize location " << loc
<< " to version " << targetVersion << '\n');
isLegal = false;
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});

return isLegal;
}

bool isLegalOperation(Operation* op, const Version& targetVersion) {
// Validate op
auto opInterface = dyn_cast<VersionedOpInterface>(op);
Expand All @@ -208,6 +231,7 @@ bool isLegalOperation(Operation* op, const Version& targetVersion) {
return succeeded(isLegalAttribute(attr.getValue(), targetVersion));
};
if (!llvm::all_of(op->getAttrs(), isLegalAttrFn)) return false;
LLVM_DEBUG(llvm::dbgs() << "Legal op attributes for target. " << op << '\n');

// Validate types
auto isLegalTypeFn = [&](Type t) {
Expand All @@ -216,6 +240,12 @@ bool isLegalOperation(Operation* op, const Version& targetVersion) {
if (!llvm::all_of(op->getOperandTypes(), isLegalTypeFn) ||
!llvm::all_of(op->getResultTypes(), isLegalTypeFn))
return false;
LLVM_DEBUG(llvm::dbgs() << "Legal op types for target. " << op << '\n');

// Validate location
if (!isLegalLocation(op->getLoc(), targetVersion)) return false;
LLVM_DEBUG(llvm::dbgs() << "Legal op location for target. " << op << '\n');

return true;
}

Expand Down

0 comments on commit a7f73a1

Please sign in to comment.