Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LinalgExt] Share FFT rewriting between Torch and StableHLO #19226

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

giacs-epic
Copy link
Contributor

@giacs-epic giacs-epic commented Nov 20, 2024

Take FFT rewriting code out of StableHLO to LinalgExt conversion pass and put it into a shared utility.
Implement a conversion pass for the Torch to IREE pipeline in which, whenever the input criterion is met (size is power of 2), aten.fft_rfft is converted to linalgext.fft. Decomposition of aten.fft_rfft is disabled for the rewriting to work.

@giacs-epic giacs-epic changed the title [LinalgExt] Share fft rewriting between Torch and StableHLO [LinalgExt] Share FFT rewriting between Torch and StableHLO Dec 4, 2024
@giacs-epic giacs-epic marked this pull request as ready for review December 4, 2024 13:29
@giacs-epic giacs-epic force-pushed the aten_rfft_conversion branch from 9c20742 to 382734b Compare December 4, 2024 14:54
@giacs-epic giacs-epic force-pushed the aten_rfft_conversion branch from 382734b to 9c20742 Compare December 4, 2024 14:58
…ftRfft conversion.

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
…sition.

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should have a custom linalg_ext op to represent fft such that all input dialect can first lower to it. cc: @MaheshRavishankar @rsuderman


namespace mlir::iree_compiler::IREE::LinalgExt {

std::tuple<LogicalResult, Value, Value> rewriteFft(Operation *op, Value operand,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably return FailureOr<std::pair<Value, Value>> instead

int64_t fftLength,
PatternRewriter &rewriter) {

assert(!(fftLength & (fftLength - 1)) &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
@giacs-epic giacs-epic force-pushed the aten_rfft_conversion branch from 9c20742 to d00c0b4 Compare December 5, 2024 10:48
@giacs-epic
Copy link
Contributor Author

@kuhar I have to force-push the rebased commits due to missing sign-offs. I'm addressing your feedback in the latest one.

@bjacob bjacob removed their request for review December 6, 2024 15:40
@ScottTodd
Copy link
Member

I wonder if we should have a custom linalg_ext op to represent fft such that all input dialect can first lower to it. cc: @MaheshRavishankar @rsuderman

Assuming this PR is still active, can the assigned reviewers please comment on this?

BTW @giacs-epic , feel free to ping if a PR sits without much activity for several days.

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
@giacs-epic
Copy link
Contributor Author

@ScottTodd @kuhar @MaheshRavishankar @rsuderman
I was away for a while, I just updated the PR hoping to fix the CI failure. Aside from that, this is ready for review.

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few nits. I'm not very familiar with FFT so I can't comment on the core logic, but overall makes sense to me.

If another reviewer has comments I'll defer to them, otherwise I'd be happy to approve.

@@ -39,4 +40,10 @@ collapseOpIterationDims(AttentionOp op,
ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);

// Rewrite input rfft op (dialect-agnostic) into linalg_ext.fft. Return real
// and imaginary tensor values.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more details about the API here? In particular describe how op, operand, and fftLength are used. An example using stableHLO or Torch would help too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, adding more details.

@@ -296,125 +298,27 @@ struct ScatterOpConversion final
struct FftOpConversion final : OpConversionPattern<mlir::stablehlo::FftOp> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by since I'm trying to clean up some other StableHLO code in #19792

There is another lowering for stablehlo's FFT op to Linalg:

/// Converts stablehlo.fft operation to Linalg ops.
struct FftOpConversion final : OpConversionPattern<mlir::stablehlo::FftOp> {
using OpConversionPattern::OpConversionPattern;

I think that Linalg lowering won't run since this LinalgExt lowering runs first. It could be the case that this LinalgExt lowering doesn't support all variants and some fall through to the Linalg version though.

I'm planning on deleting that lowering to Linalg as part of my cleanup, but please speak up (or contribute test cases) if I'm overlooking some load bearing usage.

Copy link
Contributor Author

@giacs-epic giacs-epic Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ScottTodd Thank you for writing.
Indeed the conversion I'm refactoring here applies only to cases where the input has power of 2 size (so it's a special case), all the other cases fall under the lowering you are removing in your PR.
I'm not knowledgeable about the StableHLO side of things, as I'm mainly focused on the Torch path. But it seems to me like the lowering your are removing is indeed the "main" one, although I'm surprised to see there are no tests exercising it currently (I only looked through lit tests though). Will it be replaced by the upstream code on openxla/stablehlo#1817 ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'll keep the StableHLO lowering then. Glad I checked.

Indeed, the LinalgExt lowering only powers power of 2 sizes:

// Only handle 2^n fft length.
auto operandType =
llvm::dyn_cast<RankedTensorType>(adaptor.getOperand().getType());
if (!operandType || !operandType.hasStaticShape()) {
return failure();
}
if (!llvm::all_equal(op.getFftLength())) {
return rewriter.notifyMatchFailure(op, "non-splat length");
}
int fftLength = op.getFftLength().front();
if (fftLength & (fftLength - 1)) {
return rewriter.notifyMatchFailure(
op, "expected FFT length to be a power of two");
}

The original change that added a lowering path added only lit and e2e tests for powers of two: #5450, then when a special lowering was added for LinalgExt , the e2e tests in https://github.com/iree-org/iree/blob/main/tests/e2e/stablehlo_ops/fft.mlir switched to using that path and the old path now appears to be untested.

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
@giacs-epic
Copy link
Contributor Author

A few nits. I'm not very familiar with FFT so I can't comment on the core logic, but overall makes sense to me.

If another reviewer has comments I'll defer to them, otherwise I'd be happy to approve.

@qedawkins Thank you for the review. I've applied the changes.

Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants