GEMM Node added
akshitgaur2005 committed Mar 1, 2025
1 parent e4b7e0e · commit 3f65e28
Showing 4 changed files with 21 additions and 15 deletions.
Binary file modified (not shown): crates/burn-import/onnx-tests/tests/gemm/gemm.onnx
12 changes: 6 additions & 6 deletions crates/burn-import/onnx-tests/tests/gemm/gemm.py
@@ -12,14 +12,14 @@ def create_gemm_model(output_path="gemm.onnx"):
         output_path (str): Path to save the ONNX model
     """
     # Define input and output shapes
-    batch_size = 2
-    m, k, n = 2, 3, 4  # A: (m, k), B: (k, n), C: (m, n)
+    # batch_size = 1
+    m, k, n = 2, 2, 2  # A: (m, k), B: (k, n), C: (m, n)
 
     # Define the graph inputs and outputs
-    A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [batch_size, m, k])
-    B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [batch_size, k, n])
-    C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [batch_size, m, n])
-    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [batch_size, m, n])
+    A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [m, k])
+    B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [k, n])
+    C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [m, n])
+    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [m, n])
 
     # Define Gemm node attributes
     alpha = 1.0
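Background for the shape change above: ONNX's Gemm operator is defined on 2-D inputs — A is (M, K), B is (K, N), and the optional C must be broadcastable to (M, N) — so the script now builds the plain 2-D form rather than the earlier (batch_size, m, k) shapes. As a sanity check of what the node computes, below is a minimal standalone sketch of Y = alpha·(A × B) + beta·C, using the values from the updated Rust test further down; alpha = 1.0 is visible above, while beta = 1.0 is an assumption consistent with the expected test output.

```rust
// Sketch only (not part of the commit): check Y = alpha * (A x B) + beta * C
// for the 2x2 values used in the updated test, assuming alpha = beta = 1.0.
fn main() {
    let a = [[1.0_f32, 2.0], [3.0, 4.0]];
    let b = [[5.0_f32, 6.0], [7.0, 8.0]];
    let c = [[0.0_f32, 1.0], [2.0, 3.0]];

    let mut y = [[0.0_f32; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                y[i][j] += a[i][k] * b[k][j]; // A x B
            }
            y[i][j] += c[i][j]; // + beta * C, with beta = 1.0
        }
    }

    // Matches the `expected` tensor in the test below.
    assert_eq!(y, [[19.0, 23.0], [45.0, 53.0]]);
}
```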
13 changes: 8 additions & 5 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -2315,18 +2315,21 @@ mod tests {
         let b =
             Tensor::<Backend, 2>::from_data(TensorData::from([[5.0, 6.0], [7.0, 8.0]]), &device);
 
+        let c =
+            Tensor::<Backend, 2>::from_data(TensorData::from([[0.0, 1.0], [2.0, 3.0]]), &device);
+
         // Expected result of matrix multiplication
-        // [1.0, 2.0] × [5.0, 6.0] = [1×5 + 2×7, 1×6 + 2×8] = [19.0, 22.0]
-        // [3.0, 4.0] × [7.0, 8.0] = [3×5 + 4×7, 3×6 + 4×8] = [43.0, 50.0]
+        // row [1.0, 2.0]: A×B gives [1×5 + 2×7, 1×6 + 2×8] = [19.0, 22.0]; + C row [0.0, 1.0] = [19.0, 23.0]
+        // row [3.0, 4.0]: A×B gives [3×5 + 4×7, 3×6 + 4×8] = [43.0, 50.0]; + C row [2.0, 3.0] = [45.0, 53.0]
         let expected = Tensor::<Backend, 2>::from_data(
-            TensorData::from([[19.0, 22.0], [43.0, 50.0]]),
+            TensorData::from([[19.0, 23.0], [45.0, 53.0]]),
             &device,
         );
 
         // Run the model
-        let output = model.forward(a, b);
+        let output = model.forward(a, b, c);
 
         // Verify the output
-        output.to_data().assert_approx_eq(&expected.to_data(), 3);
+        output.to_data().assert_eq(&expected.to_data(), true);
     }
 }
11 changes: 7 additions & 4 deletions crates/burn-import/src/burn/node/gemm.rs
@@ -54,14 +54,16 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for GemmNode {
             quote! {#b}
         };
 
-        let product = quote! {#a.matmul(#b)};
+        let product = quote! {#a.clone().matmul(#b.clone())};
         let scaled_product = quote! {#product * #alpha};
 
         if let Some(ref c) = self.c {
             let c = scope.tensor_use_owned(c, node_position);
 
             quote! {
-                let #output = (#scaled_product) + (#c * #beta);
+                let mut d = (#scaled_product).zeros_like();
+                d = #c.unsqueeze();
+                let #output = (#scaled_product) + (d * #beta);
             }
         } else {
             quote! {
@@ -128,8 +130,9 @@ mod tests {
     }
 
     #[allow(clippy::let_and_return, clippy::approx_constant)]
-    pub fn forward(&self, tensor1: Tensor<B, 2>) -> Tensor<B, 2> {
-        "hello"
+    pub fn forward(&self, tensor1: Tensor<B, 2>, tensor2: Tensor<B, 2>) -> Tensor<B, 2> {
+        let tensor3 = tensor1.clone().matmul(tensor2.clone()) * 1f32;
+        tensor3
     }
 }
 };
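For reference, when C is present the quote! branch in the first hunk above should expand to code along these lines (a sketch; the identifier names are illustrative, and alpha/beta are rendered as 1f32 to mirror the two-input expectation in the test snippet above). The `zeros_like()` assignment that is immediately overwritten appears to exist only to pin `d` to the rank of the product, so that the target rank of `unsqueeze()` on C can be inferred:

```rust
// Sketch of the generated forward when C is present (names illustrative):
pub fn forward(
    &self,
    tensor1: Tensor<B, 2>,
    tensor2: Tensor<B, 2>,
    tensor3: Tensor<B, 2>,
) -> Tensor<B, 2> {
    // d starts as a zero tensor with the product's shape, then is replaced
    // by C unsqueezed to that same rank.
    let mut d = (tensor1.clone().matmul(tensor2.clone()) * 1f32).zeros_like();
    d = tensor3.unsqueeze();
    let tensor4 = (tensor1.clone().matmul(tensor2.clone()) * 1f32) + (d * 1f32);
    tensor4
}
```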
