Gemm #2841
base: main
Conversation
Currently stuck on this. Any help will be appreciated!

INFO onnx_ir::from_onnx: Parsing ONNX file: ./onnx-tests/tests/gemm/gemm.onnx
DEBUG onnx_ir::from_onnx: Number of nodes: 1
DEBUG onnx_ir::from_onnx: Number of inputs: 3
DEBUG onnx_ir::from_onnx: Number of initializers: 3
DEBUG onnx_ir::from_onnx: Number of outputs: 1
WARN onnx_ir::from_onnx: Input A is also an initializer. Initializer as default values are currently not supported
WARN onnx_ir::from_onnx: Input B is also an initializer. Initializer as default values are currently not supported
WARN onnx_ir::from_onnx: Input C is also an initializer. Initializer as default values are currently not supported
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "Gemm"
DEBUG onnx_ir::from_onnx: renaming node "gemm_node"
DEBUG onnx_ir::from_onnx: adding node "gemm1"
INFO onnx_ir::from_onnx: Finished parsing ONNX file: ./onnx-tests/tests/gemm/gemm.onnx
DEBUG burn_import::onnx::to_burn: Writing debug graph file: "./out/gemm.graph.txt"
DEBUG burn_import::burn::graph: Registering node => 'gemm'
INFO burn_core::record::file: File exists, replacing
DEBUG burn_import::burn::graph: Building the scope nodes len => '1'
ERROR burn_import::logger: PANIC => panicked at crates/burn-import/src/burn/scope.rs:43:13:
No variable with name input1
thread 'main' panicked at crates/burn-import/src/burn/scope.rs:43:13:
No variable with name input1
It works now with the gemm.onnx file but not with the model in linear.onnx:

INFO onnx_ir::from_onnx: Parsing ONNX file: ./onnx-tests/tests/linear/linear.onnx
DEBUG onnx_ir::from_onnx: Number of nodes: 4
DEBUG onnx_ir::from_onnx: Number of inputs: 3
DEBUG onnx_ir::from_onnx: Number of initializers: 5
DEBUG onnx_ir::from_onnx: Number of outputs: 3
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "Gemm"
DEBUG onnx_ir::from_onnx: renaming node "/linear1_with_gemm/Gemm"
DEBUG onnx_ir::from_onnx: adding node "gemm1"
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "MatMul"
DEBUG onnx_ir::from_onnx: renaming node "/linear2_with_matmul/MatMul"
DEBUG onnx_ir::coalesce: peeking next node for bias conversion
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "MatMul"
DEBUG onnx_ir::from_onnx: adding node "matmul1"
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "MatMul"
DEBUG onnx_ir::from_onnx: renaming node "/linear3_with_matmul/MatMul"
DEBUG onnx_ir::coalesce: peeking next node for bias conversion
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "Add"
WARN onnx_ir::from_onnx: Input /linear3_with_matmul/MatMul_output_0 not found, should only happen when peeking
DEBUG onnx_ir::from_onnx: adding node "matmul2"
INFO onnx_ir::from_onnx: Finished parsing ONNX file: ./onnx-tests/tests/linear/linear.onnx
DEBUG burn_import::onnx::to_burn: Writing debug graph file: "./out/linear.graph.txt"
ERROR burn_import::logger: PANIC => panicked at crates/burn-import/src/burn/ty.rs:207:19:
"linear1_with_gemm.weight" is not a valid Ident
thread 'main' panicked at crates/burn-import/src/burn/ty.rs:207:19:
"linear1_with_gemm.weight" is not a valid Ident |
Ok, will take a look at this in a bit.
"linear1_with_gemm.weight" is not a valid Ident
This is simply because the input name contains a period .
, and when we convert the node input names to variables this is invalid. We could apply a filter to replace periods with underscores in the name.
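For illustration, a rough sketch of such a filter (the function name and where it would be called are assumptions, not part of this PR):

```rust
/// Hypothetical helper (not part of the PR): make an ONNX value name usable as a
/// Rust identifier by replacing any character that is not alphanumeric or '_' with '_'.
fn sanitize_name(name: &str) -> String {
    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect()
}

fn main() {
    // "linear1_with_gemm.weight" -> "linear1_with_gemm_weight"
    assert_eq!(
        sanitize_name("linear1_with_gemm.weight"),
        "linear1_with_gemm_weight"
    );
}
```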
But I see that you have removed the gemm_to_linear conversion; we could keep that when it matches the pattern!
What would be the benefit of doing that, though? Handling Gemm nodes as Gemm would keep things straightforward, in my opinion: if someday there is a bug related to the Gemm node, we can debug it directly instead of going through cases and wondering whether the bug is in Gemm's implementation, Linear's implementation, or the conversion logic.
So should that be handled in this PR or a separate one?
Well the conversion could be handled by the
It might not be required if we perform the conversion to linear, as it was handled before. But otherwise it could be included in this PR if it's required to make things work!
Wait, how do I actually do it though? Any function trying to do so would have two output types: either Linear and Gemm, or direct modification (no output) and Gemm. Sorry, I am a beginner to Rust.
Maybe I lost the context through the replies, but I'm not sure exactly what you mean by that. Generally speaking, you could return an enum and have two variants based on the input. But in this case specifically I am not sure exactly which function you are referring to (I'm guessing something to make a distinction between a regular Gemm node and a Linear-pattern Gemm?). So as I said earlier, keeping the current LinearNode conversion might just be easier, since it is a distinct Gemm node pattern that we can already handle. Paths for debugging should still be pretty clear even if the input is a Gemm node, since we output the graph info.
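To illustrate the enum idea, a minimal sketch (all names and the classification conditions are purely illustrative, not the PR's actual types):

```rust
/// Illustrative only: a single return type that distinguishes the two cases.
enum GemmConversion {
    /// The Gemm node matches the Linear pattern and can be converted directly.
    Linear { with_bias: bool },
    /// Fall back to emitting a full Gemm node.
    Gemm { alpha: f32, beta: f32 },
}

/// Hypothetical classifier: decides which variant applies based on the node.
fn classify(trans_a: bool, alpha: f32, beta: f32, weight_is_initializer: bool) -> GemmConversion {
    if !trans_a && alpha == 1.0 && beta == 1.0 && weight_is_initializer {
        GemmConversion::Linear { with_bias: true }
    } else {
        GemmConversion::Gemm { alpha, beta }
    }
}

fn main() {
    match classify(false, 1.0, 1.0, true) {
        GemmConversion::Linear { with_bias } => println!("convert to Linear (bias: {with_bias})"),
        GemmConversion::Gemm { alpha, beta } => println!("emit Gemm (alpha={alpha}, beta={beta})"),
    }
}
```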
I was talking about a single function to delegate the responsibility for both cases: the full Gemm node and the one already implemented. As the convert-to-linear function is executed before the actual conversion to Rust code, I guess just removing the panics should be enough. Let me just try that!
Could somebody help with this error?

error[E0308]: mismatched types
--> /home/akshit/Storage/projects/burn/target/debug/build/onnx-tests-41879a2725f93fcf/out/model/gemm.rs:49:9
|
47 | ) -> Tensor<B, 2> {
| ------------ expected `Tensor<B, 2>` because of return type
48 | let gemm1_out1 = (input1.matmul(input2) * 1f32) + (input3 * 1f32);
| ------------------------------------------------ here the type of `gemm1_out1` is inferred to be `Tensor<B, 3>`
49 | gemm1_out1
| ^^^^^^^^^^ expected `2`, found `3`
|
= note: expected struct `Tensor<_, _, 2>`
found struct `Tensor<_, _, 3>`
error[E0061]: this method takes 3 arguments but 2 arguments were supplied
--> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:28
|
2327 | let output = model.forward(a, b);
| ^^^^^^^------ argument #3 of type `Tensor<NdArray, 3>` is missing
|
note: expected `3`, found `2`
--> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:36
|
2327 | let output = model.forward(a, b);
| ^
= note: expected struct `Tensor<_, _, 3>`
found struct `Tensor<_, _, 2>`
note: expected `3`, found `2`
--> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:39
|
2327 | let output = model.forward(a, b);
| ^
= note: expected struct `Tensor<_, _, 3>`
found struct `Tensor<_, _, 2>`
note: method defined here
--> /home/akshit/Storage/projects/burn/target/debug/build/onnx-tests-41879a2725f93fcf/out/model/gemm.rs:42:12
|
42 | pub fn forward(
| ^^^^^^^
43 | &self,
44 | input1: Tensor<B, 3>,
| --------------------
45 | input2: Tensor<B, 3>,
| --------------------
46 | input3: Tensor<B, 3>,
| --------------------
help: provide the argument
|
2327 | let output = model.forward(/* Tensor<NdArray, 3> */, /* Tensor<NdArray, 3> */, /* Tensor<NdArray, 3> */);

I can't see how the rank is ending up at 3?!
> Could somebody help with this error?
> [...]
> I can't see how the rank is ending up at 3?!
The test case you defined doesn't match the ONNX model in onnx-tests/tests/gemm/gemm.onnx, which is the one you're using given the defined input in the build.rs. That model expects three inputs of shape [2, 2, 3], [2, 3, 4] and [2, 2, 4] respectively. Hence, the compiler correctly points out the error:
error[E0061]: this method takes 3 arguments but 2 arguments were supplied
--> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:28
|
2327 | let output = model.forward(a, b);
[...]
note: method defined here
--> /home/akshit/Storage/projects/burn/target/debug/build/onnx-tests-41879a2725f93fcf/out/model/gemm.rs:42:12
|
42 | pub fn forward(
| ^^^^^^^
43 | &self,
44 | input1: Tensor<B, 3>,
| --------------------
45 | input2: Tensor<B, 3>,
| --------------------
46 | input3: Tensor<B, 3>,
| --------------------
If you want to have tests for different node configurations, you might need to have additional ONNX models (to be included accordingly in the onnx_tests).
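For reference, a minimal sketch of a test whose inputs match those shapes. The `gemm::Model` path and its `new` constructor are assumptions about the generated code, so adjust to whatever burn-import actually emits:

```rust
use burn::tensor::Tensor;
use burn_ndarray::NdArray;

type Backend = NdArray<f32>;

#[test]
fn gemm_inputs_match_model_shapes() {
    let device = Default::default();

    // The model in onnx-tests/tests/gemm/gemm.onnx expects three rank-3 inputs.
    let a = Tensor::<Backend, 3>::ones([2, 2, 3], &device);
    let b = Tensor::<Backend, 3>::ones([2, 3, 4], &device);
    let c = Tensor::<Backend, 3>::ones([2, 2, 4], &device);

    // Assumed constructor; the generated file may expose `Model::new(&device)`
    // or `Model::default()` depending on the burn-import version.
    let model = gemm::Model::<Backend>::new(&device);

    // A (2x2x3) matmul B (2x3x4) plus C (2x2x4) -> output of shape [2, 2, 4].
    let output = model.forward(a, b, c);
    assert_eq!(output.dims(), [2, 2, 4]);
}
```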
Codecov Report

Attention: Patch coverage is [...]

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2841      +/-   ##
==========================================
+ Coverage   82.30%   82.31%   +0.01%
==========================================
  Files         863      864       +1
  Lines      116968   117153     +185
==========================================
+ Hits        96268    96438     +170
- Misses      20700    20715      +15

☔ View full report in Codecov by Sentry.
The generated codecov report contains regressions only on files that are not even part of this PR, and I haven't removed any tests?? New to these things btw, so sorry if it is something stupid.
You can safely ignore it; the codecov bot probably detected those as changes since they were included from the merge commit with main 😅
Good job! Glad to see you managed to make it work.

The current implementation is a bit restrictive, so we should either cover the whole scope or panic with a clear message for unsupported configurations (e.g., C is a scalar).
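As a sketch of the "panic with a clear message" option (the function and how the rank is obtained are illustrative, not the actual onnx-ir types used in the PR):

```rust
/// Illustrative guard: reject configurations the Gemm conversion does not handle yet,
/// with a message that tells the user exactly what is unsupported.
fn validate_gemm_c_rank(c_rank: usize) {
    assert!(
        c_rank >= 1,
        "Gemm: scalar C input (rank 0) is not supported yet; \
         provide C as a tensor of rank >= 1 or open an issue for this configuration"
    );
}

fn main() {
    validate_gemm_c_rank(2); // fine
    // validate_gemm_c_rank(0); // would panic with the message above
}
```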
We should probably add some other tests, since the Gemm node is only tested with inputs A, B, and C all of the same shape. It's not strictly required to cover all the possible combinations, since there are quite a few optional attributes and conditions in this particular case. For example, we could add another model (or include another node in the current ONNX) with one of the inputs transposed and no C. If we add support for a scalar C input, we should also include it in a test.
See also my comments on the implementation below 🙂
Some minor comments left; mostly we need to correctly handle the output rank and ideally add some tests for other node configurations.
Should be good to go after that! 🙂
};

#[test]
fn test_codegen_nodes() {
Like I mentioned in the last review, we should add tests for some of the other configurations! Especially for alpha or beta != 1.0 and no C. At the very least for the codegen; the test ONNX model itself is not as strict of a requirement.
How am I supposed to set the values of the attrs in the test? And would each test have to have a different test_codegen? Could you please tell me how to do these things?
Sorry, guess I should have been more precise 😅
> a different test_codegen?
Yep, precisely! Basically each instance can be defined as a separate codegen test. See for example the gather node: https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/burn/node/gather.rs#L150
They're separate test cases after all :)
I have added two new test cases, one for alpha=beta=0.5 and one for no C. Although the codegen test works fine, it does not seem to care what the output is? Could you take a look at that?
For example:
let expected =
    Tensor::<Backend, 2>::from_data(TensorData::from([[19.0, 22.], [43., 50.]]), &device);
Changing the values in this makes no difference?
Pull Request Template
Checklist
run-checks all script has been executed.

Related Issues/PRs
#2049
#1544
Changes
Implement full Gemm Node for ONNX
Testing
All tests passing on my machine, except for the remainder ones (#2787).