Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX][TORCH] Add Onnx->Linalg lowering for RotaryEmbedding Op #4002

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

vivekkhandelwal1
Copy link
Collaborator

This commit adds the Onnx->Linalg lowering for Onnx's RotaryEmbedding op (ref: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftrotaryembedding) by registering a customized torch op named OnnxVariantAtenRotaryEmbeddingOp. This is done so that the Onnx's RotaryEmbedding op can be lowered to this op and this op can be lowered from Torch->Linalg.

The lowering has been adopted from the OnnxRuntime. Files for references:
1.) https://github.com/microsoft/onnxruntime/blob/e1e3f623f61816008e79dddc91a51ffe7f0ff5cf/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc#L47-L93
2.) https://github.com/microsoft/onnxruntime/blob/94c69f55d480cb4a8dcbc161d29ef3acca9392a7/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h

Signed-off-by: Vivek Khandelwal vivekkhandelwal1424@gmail.com

@AmosLewis
Copy link
Collaborator

We need a test in https://github.com/nod-ai/SHARK-TestSuite/tree/main/alt_e2eshark/onnx_tests/operators to verify the numeric before merge

@vivekkhandelwal1 vivekkhandelwal1 force-pushed the rotary-embedding branch 2 times, most recently from 16f397a to caa9622 Compare February 10, 2025 10:10
@vivekkhandelwal1
Copy link
Collaborator Author

We need a test in https://github.com/nod-ai/SHARK-TestSuite/tree/main/alt_e2eshark/onnx_tests/operators to verify the numeric before merge

Actually, the test in SHARK-Testsuite is not working since the op comes from "com.microsoft" domain. Alhtough, I have verified the e2e correctness of lowering by manually generating the IR and then compiling and executing it with the IREE.

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few small comments. I haven't double checked that the implementation is correct.

Add custom parser and printer for the op
Move the op lowering to a seperate code file for com.microsoft domain ops
@zjgarvey
Copy link
Collaborator

Give me another day to consider this more. This is the first time we've added a custom torch op to try and support an onnx op, so I'd like to make sure we set a good precedent for how to approach this in the future.

Do you know if the output shape of this op is guaranteed to be the same as operand[0]?

@vivekkhandelwal1
Copy link
Collaborator Author

Give me another day to consider this more. This is the first time we've added a custom torch op to try and support an onnx op, so I'd like to make sure we set a good precedent for how to approach this in the future.

Do you know if the output shape of this op is guaranteed to be the same as operand[0]?

Yeah, the output shape is same as the input operand. It's mention here: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#outputs-80

@zjgarvey
Copy link
Collaborator

I submitted a PR after playing around with the shape and dtype inference for unregistered ops. Take a look at vivekkhandelwal1#4 and let me know what you think.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants