From 3f65e280eaf3a2ac504a3a736909115b2ee49668 Mon Sep 17 00:00:00 2001 From: Akshit Gaur Date: Sat, 1 Mar 2025 18:11:11 +0530 Subject: [PATCH] GEMM Node added --- .../burn-import/onnx-tests/tests/gemm/gemm.onnx | Bin 231 -> 215 bytes .../burn-import/onnx-tests/tests/gemm/gemm.py | 12 ++++++------ .../burn-import/onnx-tests/tests/test_onnx.rs | 13 ++++++++----- crates/burn-import/src/burn/node/gemm.rs | 11 +++++++---- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/gemm/gemm.onnx b/crates/burn-import/onnx-tests/tests/gemm/gemm.onnx index f690f14df4f8038e4d8059de706f50270455b6d9..569c1615e38ef9d82942809dee8552cc9b70e4f6 100644 GIT binary patch delta 118 zcmaFPc%4y%gG-3d-_I{1-aR!hwJ5P9zsPFOM3JmWVJ=2TAwDi14n`pkE+!5p5RL*0 XI$;xZMi)#1s*FSzbYkIR5D)+W&>Rlm delta 134 zcmcc4_?%IMgG-3d-_I{1-aR!hwJ5P9zsTy$M3JmaaV|zjAt5dS4n`p!E+!5pC}xfV gN;siOFasr7KoZVq5+GeHNkEB6Y!Xf^TnqvN0L97=p8x;= diff --git a/crates/burn-import/onnx-tests/tests/gemm/gemm.py b/crates/burn-import/onnx-tests/tests/gemm/gemm.py index f97c8e3b8f..b67ed825dd 100644 --- a/crates/burn-import/onnx-tests/tests/gemm/gemm.py +++ b/crates/burn-import/onnx-tests/tests/gemm/gemm.py @@ -12,14 +12,14 @@ def create_gemm_model(output_path="gemm.onnx"): output_path (str): Path to save the ONNX model """ # Define input and output shapes - batch_size = 2 - m, k, n = 2, 3, 4 # A: (m, k), B: (k, n), C: (m, n) + # batch_size = 1 + m, k, n = 2, 2, 2 # A: (m, k), B: (k, n), C: (m, n) # Define the graph inputs and outputs - A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [batch_size, m, k]) - B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [batch_size, k, n]) - C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [batch_size, m, n]) - Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [batch_size, m, n]) + A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [m, k]) + B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [k, n]) + C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [m, n]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [m, n]) # Define Gemm node attributes alpha = 1.0 diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 668354967a..455fbedd80 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -2315,18 +2315,21 @@ mod tests { let b = Tensor::::from_data(TensorData::from([[5.0, 6.0], [7.0, 8.0]]), &device); + let c = + Tensor::::from_data(TensorData::from([[0.0, 1.0], [2.0, 3.0]]), &device); + // Expected result of matrix multiplication - // [1.0, 2.0] × [5.0, 6.0] = [1×5 + 2×7, 1×6 + 2×8] = [19.0, 22.0] - // [3.0, 4.0] × [7.0, 8.0] = [3×5 + 4×7, 3×6 + 4×8] = [43.0, 50.0] + // [1.0, 2.0] × [5.0, 6.0] = [1×5 + 2×7, 1×6 + 2×8] = [19.0 + 0.0, 22.0 + 1.0] = [19.0, 23.0] + // [3.0, 4.0] × [7.0, 8.0] = [3×5 + 4×7, 3×6 + 4×8] = [43.0 + 2.0, 50.0 + 3.0] = [45.0, 53.0] let expected = Tensor::::from_data( - TensorData::from([[19.0, 22.0], [43.0, 50.0]]), + TensorData::from([[19.0, 23.0], [45.0, 53.0]]), &device, ); // Run the model - let output = model.forward(a, b); + let output = model.forward(a, b, c); // Verify the output - output.to_data().assert_approx_eq(&expected.to_data(), 3); + output.to_data().assert_eq(&expected.to_data(), true); } } diff --git a/crates/burn-import/src/burn/node/gemm.rs b/crates/burn-import/src/burn/node/gemm.rs index fb16513e38..afb89bce21 100644 --- a/crates/burn-import/src/burn/node/gemm.rs +++ b/crates/burn-import/src/burn/node/gemm.rs @@ -54,14 +54,16 @@ impl NodeCodegen for GemmNode { quote! {#b} }; - let product = quote! {#a.matmul(#b)}; + let product = quote! {#a.clone().matmul(#b.clone())}; let scaled_product = quote! {#product * #alpha}; if let Some(ref c) = self.c { let c = scope.tensor_use_owned(c, node_position); quote! { - let #output = (#scaled_product) + (#c * #beta); + let mut d = (#scaled_product).zeros_like(); + d = #c.unsqueeze(); + let #output = (#scaled_product) + (d * #beta); } } else { quote! { @@ -128,8 +130,9 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self, tensor1: Tensor) -> Tensor { - "hello" + pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { + let tensor3 = tensor1.clone().matmul(tensor2.clone()) * 1f32; + tensor3 } } };