Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
jorickert committed Feb 21, 2025
1 parent bf6bed9 commit bd75fd9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 22 deletions.
50 changes: 29 additions & 21 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1596,19 +1596,17 @@ namespace {

} // namespace

template <typename OpToCreate, typename Derived>
struct CustomOpMicrosoftQDuantizeLinear
: public OpRewritePattern<ONNXCustomOp> {
using OpRewritePattern<ONNXCustomOp>::OpRewritePattern;

LogicalResult matchAndRewrite(
ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
template <typename OpToCreate>
struct CustomOpMicrosoftQDuantizeLinear {
LogicalResult matchAndRewriteImpl(ONNXCustomOp customOp,
PatternRewriter &rewriter, StringRef expectedName) const {
using namespace onnx_mlir;

if (!isCustomMicrosoftOp(
customOp, static_cast<const Derived *>(this)->expectedName))
if (!isCustomMicrosoftOp(customOp, expectedName))
return failure();
assert(customOp->getNumOperands() == 3);
if (customOp->getNumOperands() != 3) {
return failure();
}

const auto scale = customOp->getOperand(1);
const auto zeroPoint = customOp->getOperand(2);
Expand All @@ -1623,7 +1621,13 @@ struct CustomOpMicrosoftQDuantizeLinear
zeroPoint);

IgnoreDiagnostic diag(customOp->getContext()->getDiagEngine());
if (failed(mlir::verify(newOp))) {
bool isNewOpValid;
if (auto info = newOp->getName().getRegisteredInfo()) {
isNewOpValid = succeeded(info->verifyInvariants(newOp));
} else {
isNewOpValid = succeeded(mlir::verify(newOp));
}
if (!isNewOpValid) {
rewriter.eraseOp(newOp);
return rewriter.notifyMatchFailure(customOp, "Failed verification");
}
Expand All @@ -1633,19 +1637,23 @@ struct CustomOpMicrosoftQDuantizeLinear
};

struct CustomOpMicrosoftQuantizeLinear
: public CustomOpMicrosoftQDuantizeLinear<ONNXQuantizeLinearOp,
CustomOpMicrosoftQuantizeLinear> {
const std::string expectedName = "QuantizeLinear";
using CustomOpMicrosoftQDuantizeLinear<ONNXQuantizeLinearOp,
CustomOpMicrosoftQuantizeLinear>::CustomOpMicrosoftQDuantizeLinear;
: public OpRewritePattern<ONNXCustomOp>,
public CustomOpMicrosoftQDuantizeLinear<ONNXQuantizeLinearOp> {
using OpRewritePattern<ONNXCustomOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
return matchAndRewriteImpl(customOp, rewriter, "QuantizeLinear");
}
};

struct CustomOpMicrosoftDequantizeLinear
: public CustomOpMicrosoftQDuantizeLinear<ONNXDequantizeLinearOp,
CustomOpMicrosoftDequantizeLinear> {
const std::string expectedName = "DequantizeLinear";
using CustomOpMicrosoftQDuantizeLinear<ONNXDequantizeLinearOp,
CustomOpMicrosoftDequantizeLinear>::CustomOpMicrosoftQDuantizeLinear;
: public OpRewritePattern<ONNXCustomOp>,
CustomOpMicrosoftQDuantizeLinear<ONNXDequantizeLinearOp> {
using OpRewritePattern<ONNXCustomOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
return matchAndRewriteImpl(customOp, rewriter, "DequantizeLinear");
}
};

// Transform InstanceNormalization into LayerNormalization
Expand Down
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/Transforms/Recompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
MLIRContext *context = patterns.getContext();
patterns.insert<RecomposeGeluFromMulPattern>(context);
patterns.insert<RecomposeLayerNormFromMulPattern>(context);
// AMD Disabled
// AMD Disabled as downstream has no special support for it
// patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
}

Expand Down

0 comments on commit bd75fd9

Please sign in to comment.