
Gemm #2841 (Open)

akshitgaur2005 wants to merge 13 commits into main
Conversation

@akshitgaur2005 (Contributor) commented Feb 24, 2025

Pull Request Template

Checklist

  • Confirmed that the `run-checks all` script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2049
#1544

Changes

Implement full Gemm Node for ONNX
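For reference, the ONNX Gemm operator computes Y = alpha * A' * B' + beta * C, where A' is A transposed when transA is set (likewise for B). A minimal std-only Rust sketch of that definition — illustrative only, not the PR's implementation:

```rust
// Reference sketch of the ONNX Gemm definition over row-major matrices.
// Y = alpha * A' * B' + beta * C, with A' = A^T when trans_a (same for B).
fn gemm(
    a: &[Vec<f32>],
    b: &[Vec<f32>],
    c: &[Vec<f32>],
    alpha: f32,
    beta: f32,
    trans_a: bool,
    trans_b: bool,
) -> Vec<Vec<f32>> {
    // Transpose helper for row-major matrices.
    fn t(m: &[Vec<f32>]) -> Vec<Vec<f32>> {
        (0..m[0].len())
            .map(|j| m.iter().map(|row| row[j]).collect())
            .collect()
    }
    let a = if trans_a { t(a) } else { a.to_vec() };
    let b = if trans_b { t(b) } else { b.to_vec() };
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    (0..m)
        .map(|i| {
            (0..n)
                .map(|j| {
                    let dot: f32 = (0..k).map(|p| a[i][p] * b[p][j]).sum();
                    alpha * dot + beta * c[i][j]
                })
                .collect()
        })
        .collect()
}
```

Note that the sketch assumes C already has the output shape; the real operator also allows unidirectional broadcasting of C.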

Testing

All tests pass on my machine, except for the remainder ones (#2787).

@akshitgaur2005 (Contributor, Author) commented:

Currently stuck on this. Any help will be appreciated!

 INFO onnx_ir::from_onnx: Parsing ONNX file: ./onnx-tests/tests/gemm/gemm.onnx
DEBUG onnx_ir::from_onnx: Number of nodes: 1
DEBUG onnx_ir::from_onnx: Number of inputs: 3
DEBUG onnx_ir::from_onnx: Number of initializers: 3
DEBUG onnx_ir::from_onnx: Number of outputs: 1
 WARN onnx_ir::from_onnx: Input A is also an initializer. Initializer as default values are currently not supported
 WARN onnx_ir::from_onnx: Input B is also an initializer. Initializer as default values are currently not supported
 WARN onnx_ir::from_onnx: Input C is also an initializer. Initializer as default values are currently not supported
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "Gemm"
DEBUG onnx_ir::from_onnx: renaming node "gemm_node"
DEBUG onnx_ir::from_onnx: adding node "gemm1"
 INFO onnx_ir::from_onnx: Finished parsing ONNX file: ./onnx-tests/tests/gemm/gemm.onnx
DEBUG burn_import::onnx::to_burn: Writing debug graph file: "./out/gemm.graph.txt"
DEBUG burn_import::burn::graph: Registering node => 'gemm'
 INFO burn_core::record::file: File exists, replacing
DEBUG burn_import::burn::graph: Building the scope nodes len => '1'
ERROR burn_import::logger: PANIC => panicked at crates/burn-import/src/burn/scope.rs:43:13:
No variable with name input1
thread 'main' panicked at crates/burn-import/src/burn/scope.rs:43:13:
No variable with name input1

@laggui (Member) left a comment:
The gemm.onnx model file is invalid; the graph doesn't even have any inputs.

(screenshot of the imported model graph omitted)

You need to define the graph inputs to be used by the Gemm node.

@akshitgaur2005 (Contributor, Author) commented:

It works now with the gemm.onnx file, but not with the model in linear.onnx:

 INFO onnx_ir::from_onnx: Parsing ONNX file: ./onnx-tests/tests/linear/linear.onnx
DEBUG onnx_ir::from_onnx: Number of nodes: 4
DEBUG onnx_ir::from_onnx: Number of inputs: 3
DEBUG onnx_ir::from_onnx: Number of initializers: 5
DEBUG onnx_ir::from_onnx: Number of outputs: 3
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "Gemm"
DEBUG onnx_ir::from_onnx: renaming node "/linear1_with_gemm/Gemm"
DEBUG onnx_ir::from_onnx: adding node "gemm1"
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "MatMul"
DEBUG onnx_ir::from_onnx: renaming node "/linear2_with_matmul/MatMul"
DEBUG onnx_ir::coalesce: peeking next node for bias conversion
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "MatMul"
DEBUG onnx_ir::from_onnx: adding node "matmul1"
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "MatMul"
DEBUG onnx_ir::from_onnx: renaming node "/linear3_with_matmul/MatMul"
DEBUG onnx_ir::coalesce: peeking next node for bias conversion
DEBUG onnx_ir::proto_conversion: Converting ONNX node with type "Add"
 WARN onnx_ir::from_onnx: Input /linear3_with_matmul/MatMul_output_0 not found, should only happen when peeking
DEBUG onnx_ir::from_onnx: adding node "matmul2"
 INFO onnx_ir::from_onnx: Finished parsing ONNX file: ./onnx-tests/tests/linear/linear.onnx
DEBUG burn_import::onnx::to_burn: Writing debug graph file: "./out/linear.graph.txt"
ERROR burn_import::logger: PANIC => panicked at crates/burn-import/src/burn/ty.rs:207:19:
"linear1_with_gemm.weight" is not a valid Ident
thread 'main' panicked at crates/burn-import/src/burn/ty.rs:207:19:
"linear1_with_gemm.weight" is not a valid Ident

@laggui (Member) commented Feb 25, 2025

> It works now with the gemm.onnx file, but not with the model in linear.onnx

Ok will take a look at this in a bit

@laggui (Member) left a comment:

"linear1_with_gemm.weight" is not a valid Ident

This is simply because the input name contains a period (`.`), which is invalid when we convert the node input names to variables. We could apply a filter that replaces periods with underscores in the name.

But I see that you have removed the gemm_to_linear conversion, we could keep that when it matches the pattern!
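Such a filter could be as simple as a character replacement. A hypothetical std-only sketch — the helper name and exact sanitization rules are assumptions, not burn-import's actual code:

```rust
// Hypothetical helper: turn an ONNX tensor name like
// "linear1_with_gemm.weight" into a valid Rust identifier by replacing
// characters that are invalid in identifiers with underscores.
fn sanitize_ident(name: &str) -> String {
    let mut out: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    // Identifiers cannot start with a digit.
    if out.chars().next().map_or(true, |c| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
```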

@akshitgaur2005 (Contributor, Author) commented:

> But I see that you have removed the gemm_to_linear conversion, we could keep that when it matches the pattern!

What would be the benefit of doing that though?

Handling Gemm nodes as GEMM would keep things straightforward, in my opinion. If some bug related to the Gemm node shows up someday, we can debug it directly instead of going through cases and wondering whether the bug is in Gemm's implementation, Linear's implementation, or the conversion logic.

@akshitgaur2005 (Contributor, Author) commented:

> This is simply because the input name contains a period (`.`), which is invalid when we convert the node input names to variables. We could apply a filter that replaces periods with underscores in the name.

So should that be handled in this PR or a separate one?

@laggui (Member) commented Feb 26, 2025

> > But I see that you have removed the gemm_to_linear conversion, we could keep that when it matches the pattern!
>
> What would be the benefit of doing that though?

Well, the conversion could be handled by the GemmNode instead. If the inputs to the Gemm are initializers, then surely it was meant to be exported as a linear layer. So importing it in Burn as a linear layer makes sense imo 🤔

> So should that be handled in this PR or a separate one?

It might not be required if we perform the conversion to linear, as it was handled before. But otherwise it could be included in this PR if it's required to make things work!

@akshitgaur2005 (Contributor, Author) commented:

Wait, how do I actually do that though? Any function trying to do so would have two possible output types: either Linear and Gemm, or direct modification (no output) and Gemm.

Sorry, I am a beginner to Rust.

@laggui (Member) commented Feb 27, 2025

> Wait, how do I actually do that though? Any function trying to do so would have two possible output types: either Linear and Gemm, or direct modification (no output) and Gemm.

Maybe I lost the context through the replies, but I'm not sure exactly what you mean by that.

Generally speaking, you could return an enum and have two variants based on the input. But in this case specifically I am not sure exactly what function you are referring to (I'm guessing something to make a distinction between a regular Gemm node and a Linear gemm?)

So as I said earlier, keeping the current LinearNode conversion might just be easier since it is a distinct Gemm node pattern that we can already handle. Paths for debugging should still be pretty clear even if the input is a Gemm node since we output the graph info.
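The enum idea mentioned above could look roughly like this; every name and the exact pattern-matching condition are hypothetical, not burn's actual API:

```rust
// Illustrative only: one way a single conversion function could report
// either outcome. The real linear pattern check involves the initializer
// shapes as well; this condition is a simplified stand-in.
enum GemmLowering {
    // The Gemm matched the linear-layer pattern.
    Linear,
    // Fall back to a general Gemm node.
    Gemm { alpha: f32, beta: f32 },
}

fn classify(weights_are_initializers: bool, alpha: f32, beta: f32, trans_a: bool) -> GemmLowering {
    if weights_are_initializers && alpha == 1.0 && beta == 1.0 && !trans_a {
        GemmLowering::Linear
    } else {
        GemmLowering::Gemm { alpha, beta }
    }
}
```

Callers then match on the variant and emit either a LinearNode or a GemmNode.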

@akshitgaur2005 (Contributor, Author) commented:

I was talking about a single function that delegates responsibility for both cases: the full Gemm node and the one already implemented.

Since the convert-to-linear function is executed before the actual conversion to Rust code, I guess just removing the panics should be enough. Let me try that!

@akshitgaur2005 (Contributor, Author) commented:

Could somebody help with this error?

error[E0308]: mismatched types
  --> /home/akshit/Storage/projects/burn/target/debug/build/onnx-tests-41879a2725f93fcf/out/model/gemm.rs:49:9
   |
47 |     ) -> Tensor<B, 2> {
   |          ------------ expected `Tensor<B, 2>` because of return type
48 |         let gemm1_out1 = (input1.matmul(input2) * 1f32) + (input3 * 1f32);
   |                          ------------------------------------------------ here the type of `gemm1_out1` is inferred to be `Tensor<B, 3>`
49 |         gemm1_out1
   |         ^^^^^^^^^^ expected `2`, found `3`
   |
   = note: expected struct `Tensor<_, _, 2>`
              found struct `Tensor<_, _, 3>`

error[E0061]: this method takes 3 arguments but 2 arguments were supplied
    --> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:28
     |
2327 |         let output = model.forward(a, b);
     |                            ^^^^^^^------ argument #3 of type `Tensor<NdArray, 3>` is missing
     |
note: expected `3`, found `2`
    --> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:36
     |
2327 |         let output = model.forward(a, b);
     |                                    ^
     = note: expected struct `Tensor<_, _, 3>`
                found struct `Tensor<_, _, 2>`
note: expected `3`, found `2`
    --> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:39
     |
2327 |         let output = model.forward(a, b);
     |                                       ^
     = note: expected struct `Tensor<_, _, 3>`
                found struct `Tensor<_, _, 2>`
note: method defined here
    --> /home/akshit/Storage/projects/burn/target/debug/build/onnx-tests-41879a2725f93fcf/out/model/gemm.rs:42:12
     |
42   |     pub fn forward(
     |            ^^^^^^^
43   |         &self,
44   |         input1: Tensor<B, 3>,
     |         --------------------
45   |         input2: Tensor<B, 3>,
     |         --------------------
46   |         input3: Tensor<B, 3>,
     |         --------------------
help: provide the argument
     |
2327 |         let output = model.forward(/* Tensor<NdArray, 3> */, /* Tensor<NdArray, 3> */, /* Tensor<NdArray, 3> */);

I can't see how the rank is ending up at 3?!

@laggui (Member) left a comment:

> Could somebody help with this error?
> [...]
> I can't see how the rank is ending up at 3?!

The test case you defined doesn't match the ONNX model in onnx-tests/tests/gemm/gemm.onnx, which is the one you're using, given the input defined in `build.rs`.

That model expects three inputs of shape [2, 2, 3], [2, 3, 4] and [2, 2, 4] respectively. Hence, the compiler correctly points out the error:

error[E0061]: this method takes 3 arguments but 2 arguments were supplied
    --> crates/burn-import/onnx-tests/tests/test_onnx.rs:2327:28
     |
2327 |         let output = model.forward(a, b);
[...]
note: method defined here
    --> /home/akshit/Storage/projects/burn/target/debug/build/onnx-tests-41879a2725f93fcf/out/model/gemm.rs:42:12
     |
42   |     pub fn forward(
     |            ^^^^^^^
43   |         &self,
44   |         input1: Tensor<B, 3>,
     |         --------------------
45   |         input2: Tensor<B, 3>,
     |         --------------------
46   |         input3: Tensor<B, 3>,
     |         --------------------

If you want to have tests for different node configurations, you might need to have additional ONNX models (to include accordingly in the onnx_tests).
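The rank-3 output the compiler infers follows from the batched matmul rule: with rank-3 inputs of shape [2, 2, 3] and [2, 3, 4], matmul multiplies per batch and yields [2, 2, 4], so the result stays rank 3. A small illustrative shape helper (not burn code) makes the rule concrete:

```rust
// Output-shape rule for a batched matrix multiply over rank-3 tensors:
// [b, m, k] x [b, k, n] -> [b, m, n]. Returns None when batch sizes or
// inner dimensions do not line up. Illustrative only.
fn batched_matmul_shape(lhs: [usize; 3], rhs: [usize; 3]) -> Option<[usize; 3]> {
    let ([b1, m, k1], [b2, k2, n]) = (lhs, rhs);
    if b1 == b2 && k1 == k2 {
        Some([b1, m, n])
    } else {
        None
    }
}
```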

@akshitgaur2005 akshitgaur2005 marked this pull request as ready for review March 1, 2025 12:45
codecov bot commented Mar 1, 2025

Codecov Report

Attention: Patch coverage is 85.85859% with 28 lines in your changes missing coverage. Please review.

Project coverage is 82.31%. Comparing base (a6b5210) to head (f36368d).
Report is 3 commits behind head on main.

Files with missing lines                   Patch %   Missing
crates/burn-import/src/burn/node/gemm.rs   79.16%    25 ⚠️
crates/onnx-ir/src/rank_inference.rs       83.33%    2 ⚠️
crates/burn-import/src/onnx/to_burn.rs     90.90%    1 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2841      +/-   ##
==========================================
+ Coverage   82.30%   82.31%   +0.01%     
==========================================
  Files         863      864       +1     
  Lines      116968   117153     +185     
==========================================
+ Hits        96268    96438     +170     
- Misses      20700    20715      +15     


@akshitgaur2005 (Contributor, Author) commented:

The generated codecov report shows regressions only on files that are not even part of this PR, and I haven't removed any tests?

New to these things btw, so sorry if it is something stupid

@laggui (Member) commented Mar 3, 2025

> The generated codecov report shows regressions only on files that are not even part of this PR, and I haven't removed any tests?
>
> New to these things btw, so sorry if it is something stupid

You can safely ignore it; the codecov bot probably detected those as changes since they came in with the merge commit from main 😅

@laggui (Member) left a comment:

Good job! Glad to see you managed to make it work.

The current implementation is a bit restrictive, so we should either cover the whole scope or panic with a clear message for unsupported configurations (e.g., C is a scalar).

We should probably add some other tests, since the Gemm node is currently only tested with inputs A, B, and C of the same shape. It's not strictly required to cover all possible combinations, since there are quite a few optional attributes and conditions in this particular case. For example, we could add another model (or include another node in the current ONNX) with one of the inputs transposed and no C. If we add support for a scalar C input, we should also include it in a test.

See also my comments on the implementation below 🙂

@akshitgaur2005 akshitgaur2005 requested a review from laggui March 5, 2025 09:34
@laggui (Member) left a comment:

Some minor comments left, mostly need to correctly handle the output rank and ideally add some tests for other node configurations.

Should be good to go after that! 🙂

};

#[test]
fn test_codegen_nodes() {
@laggui (Member) commented:

Like I mentioned in the last review, we should add tests for some of the other configurations! Especially for alpha or beta != 1.0, and for no C.

At the very least for the codegen, the test ONNX model itself is not as strict of a requirement.
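For the alpha/beta != 1.0 case, the generated expression needs a scale factor. A hypothetical sketch of how codegen might emit the factor only when it is needed — the helper and its string-based output are assumptions, not burn-import's actual codegen API:

```rust
// Hypothetical codegen helper: wrap an expression in a scale only when the
// factor is not 1.0, so the common alpha = beta = 1.0 case stays clean.
fn scaled_expr(expr: &str, factor: f32) -> String {
    if factor == 1.0 {
        expr.to_string()
    } else {
        format!("({expr} * {factor}f32)")
    }
}
```

A codegen test for alpha = beta = 0.5 would then expect both the matmul term and the C term to be wrapped in such a scale.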

@akshitgaur2005 (Contributor, Author) commented Mar 6, 2025:

How am I supposed to set the values of the attrs in the test? And would each test have to have a different test_codegen?

Could you please explain how to do these things?

@laggui (Member) commented Mar 6, 2025:

Sorry, guess I should have been more precise 😅

> a different test_codegen?

Yep, precisely! Basically each instance can be defined as a separate codegen test. See for example the gather node: https://github.com/tracel-ai/burn/blob/main/crates/burn-import/src/burn/node/gather.rs#L150

They're separate test cases after all :)

@akshitgaur2005 (Contributor, Author) commented:

I have added two new test cases: one for alpha = beta = 0.5 and one with no C. Although the codegen test works fine, it doesn't seem to care what the output is? Could you take a look at that.

For example-

let expected =
            Tensor::<Backend, 2>::from_data(TensorData::from([[19.0, 22.], [43., 50.]]), &device);

Changing the values here makes no difference?
