-
Notifications
You must be signed in to change notification settings - Fork 533
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX][TORCH] Add Onnx->Linalg lowering for RotaryEmbedding Op #4002
base: main
Are you sure you want to change the base?
Conversation
20809ca
to
846fbea
Compare
a788784
to
6b40a15
Compare
We need a test in https://github.com/nod-ai/SHARK-TestSuite/tree/main/alt_e2eshark/onnx_tests/operators to verify the numeric before merge |
16f397a
to
caa9622
Compare
Actually, the test in SHARK-Testsuite is not working since the op comes from "com.microsoft" domain. Alhtough, I have verified the e2e correctness of lowering by manually generating the IR and then compiling and executing it with the IREE. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few small comments. I haven't double checked that the implementation is correct.
02705a1
to
7c72b59
Compare
Add custom parser and printer for the op Move the op lowering to a seperate code file for com.microsoft domain ops
35a43fc
to
04032da
Compare
Give me another day to consider this more. This is the first time we've added a custom torch op to try and support an onnx op, so I'd like to make sure we set a good precedent for how to approach this in the future. Do you know if the output shape of this op is guaranteed to be the same as |
Yeah, the output shape is same as the input operand. It's mention here: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#outputs-80 |
I submitted a PR after playing around with the shape and dtype inference for unregistered ops. Take a look at vivekkhandelwal1#4 and let me know what you think. |
This commit adds the Onnx->Linalg lowering for Onnx's RotaryEmbedding op (ref: https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftrotaryembedding) by registering a customized torch op named
OnnxVariantAtenRotaryEmbeddingOp
. This is done so that the Onnx's RotaryEmbedding op can be lowered to this op and this op can be lowered from Torch->Linalg.The lowering has been adopted from the OnnxRuntime. Files for references:
1.) https://github.com/microsoft/onnxruntime/blob/e1e3f623f61816008e79dddc91a51ffe7f0ff5cf/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc#L47-L93
2.) https://github.com/microsoft/onnxruntime/blob/94c69f55d480cb4a8dcbc161d29ef3acca9392a7/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
Signed-off-by: Vivek Khandelwal vivekkhandelwal1424@gmail.com