-
Notifications
You must be signed in to change notification settings - Fork 668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LinalgExt] Share FFT rewriting between Torch and StableHLO #19226
base: main
Are you sure you want to change the base?
Conversation
9c20742
to
382734b
Compare
382734b
to
9c20742
Compare
…ftRfft conversion. Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
…sition. Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we should have a custom linalg_ext op to represent fft such that all input dialect can first lower to it. cc: @MaheshRavishankar @rsuderman
|
||
namespace mlir::iree_compiler::IREE::LinalgExt { | ||
|
||
std::tuple<LogicalResult, Value, Value> rewriteFft(Operation *op, Value operand, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should probably return FailureOr<std::pair<Value, Value>>
instead
int64_t fftLength, | ||
PatternRewriter &rewriter) { | ||
|
||
assert(!(fftLength & (fftLength - 1)) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use an llvm helper to check for powers of two: https://github.com/llvm/llvm-project/blob/7d1c661381d36018fd105f4ad4c2d6dc45e7288b/llvm/include/llvm/Support/MathExtras.h#L289-L298
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
9c20742
to
d00c0b4
Compare
@kuhar I have to force-push the rebased commits due to missing sign-offs. I'm addressing your feedback in the latest one. |
Assuming this PR is still active, can the assigned reviewers please comment on this? BTW @giacs-epic , feel free to ping if a PR sits without much activity for several days. |
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
@ScottTodd @kuhar @MaheshRavishankar @rsuderman |
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few nits. I'm not very familiar with FFT so I can't comment on the core logic, but overall makes sense to me.
If another reviewer has comments I'll defer to them, otherwise I'd be happy to approve.
@@ -39,4 +40,10 @@ collapseOpIterationDims(AttentionOp op, | |||
ArrayRef<ReassociationIndices> foldedIterationDims, | |||
RewriterBase &rewriter); | |||
|
|||
// Rewrite input rfft op (dialect-agnostic) into linalg_ext.fft. Return real | |||
// and imaginary tensor values. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add more details about the API here? In particular describe how op
, operand
, and fftLength
are used. An example using stableHLO or Torch would help too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, adding more details.
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp
Outdated
Show resolved
Hide resolved
compiler/plugins/input/Torch/InputConversion/test/unstructured_linalg_ext.mlir
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp
Outdated
Show resolved
Hide resolved
@@ -296,125 +298,27 @@ struct ScatterOpConversion final | |||
struct FftOpConversion final : OpConversionPattern<mlir::stablehlo::FftOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drive-by since I'm trying to clean up some other StableHLO code in #19792
There is another lowering for stablehlo's FFT op to Linalg:
iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToIREEInputDialects.cpp
Lines 163 to 165 in bbe7f5c
/// Converts stablehlo.fft operation to Linalg ops. | |
struct FftOpConversion final : OpConversionPattern<mlir::stablehlo::FftOp> { | |
using OpConversionPattern::OpConversionPattern; |
I think that Linalg lowering won't run since this LinalgExt lowering runs first. It could be the case that this LinalgExt lowering doesn't support all variants and some fall through to the Linalg version though.
I'm planning on deleting that lowering to Linalg as part of my cleanup, but please speak up (or contribute test cases) if I'm overlooking some load bearing usage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ScottTodd Thank you for writing.
Indeed the conversion I'm refactoring here applies only to cases where the input has power of 2 size (so it's a special case), all the other cases fall under the lowering you are removing in your PR.
I'm not knowledgeable about the StableHLO side of things, as I'm mainly focused on the Torch path. But it seems to me like the lowering your are removing is indeed the "main" one, although I'm surprised to see there are no tests exercising it currently (I only looked through lit tests though). Will it be replaced by the upstream code on openxla/stablehlo#1817 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I'll keep the StableHLO lowering then. Glad I checked.
Indeed, the LinalgExt lowering only powers power of 2 sizes:
iree/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
Lines 372 to 385 in 9a34131
// Only handle 2^n fft length. | |
auto operandType = | |
llvm::dyn_cast<RankedTensorType>(adaptor.getOperand().getType()); | |
if (!operandType || !operandType.hasStaticShape()) { | |
return failure(); | |
} | |
if (!llvm::all_equal(op.getFftLength())) { | |
return rewriter.notifyMatchFailure(op, "non-splat length"); | |
} | |
int fftLength = op.getFftLength().front(); | |
if (fftLength & (fftLength - 1)) { | |
return rewriter.notifyMatchFailure( | |
op, "expected FFT length to be a power of two"); | |
} |
The original change that added a lowering path added only lit and e2e tests for powers of two: #5450, then when a special lowering was added for LinalgExt , the e2e tests in https://github.com/iree-org/iree/blob/main/tests/e2e/stablehlo_ops/fft.mlir switched to using that path and the old path now appears to be untested.
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
@qedawkins Thank you for the review. I've applied the changes. |
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/RewriteFft.cpp
Outdated
Show resolved
Hide resolved
compiler/plugins/input/Torch/InputConversion/ConvertTorchUnstructuredToLinalgExt.cpp
Outdated
Show resolved
Hide resolved
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
Outdated
Show resolved
Hide resolved
Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Take FFT rewriting code out of StableHLO to LinalgExt conversion pass and put it into a shared utility.
Implement a conversion pass for the Torch to IREE pipeline in which, whenever the input criterion is met (size is power of 2),
aten.fft_rfft
is converted tolinalgext.fft
. Decomposition ofaten.fft_rfft
is disabled for the rewriting to work.