Describe the bug
Was double-checking our broadcasting support following #1499.
The general matmul implementation for the ndarray backend doesn't seem to compute the correct result for batch matrix broadcasting.
To Reproduce
use burn::{
    backend::{NdArray, Wgpu},
    tensor::Tensor,
};

fn main() {
    type B = NdArray;
    // type B = Wgpu;
    let device = Default::default();

    // (j x 1 x n x m) @ (1 x k x m x p) -> (j x k x n x p)
    let tensor_1 = Tensor::<B, 4>::ones([3, 1, 2, 4], &device);
    let tensor_2 = Tensor::<B, 4>::ones([1, 5, 4, 6], &device);

    let dims1 = tensor_1.dims();
    let dims2 = tensor_2.dims();
    let tensor_3 = tensor_1.matmul(tensor_2);

    println!("{:?} @ {:?} = {:?}", dims1, dims2, tensor_3.dims());
    // [3, 1, 2, 4] @ [1, 5, 4, 6] = [3, 5, 2, 6]
}
This code panics at tensor_1.matmul(tensor_2):
thread '<unnamed>' panicked at /home/laggui/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/dimension/mod.rs:361:5:
collapse_axis: Index 3 must be less than axis length 3 for array with shape [3, 2, 4]
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread '<unnamed>' panicked at /home/laggui/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/dimension/mod.rs:361:5:
collapse_axis: Index 4 must be less than axis length 3 for array with shape [3, 2, 4]
Expected behavior
The matmul should produce an output with shape [3, 5, 2, 6]: the batch dims [3, 1] and [1, 5] broadcast to [3, 5], and the matrix dims give 2 x 6. This works on the wgpu backend.
EDIT: I have a WIP branch here: https://github.com/lancelet/burn/tree/matmul-broadcasting
In that branch, I use strides for the non-matrix dimensions (i.e. the batch, depth, channel, or whatever) to handle arbitrary broadcasting.
It gets the correct result for the above example, but it fails on some autoregressive test cases with an error about an incompatible memory layout. This happens when I try to flatten/reshape the arrays in the same way that the original code does. Still investigating.
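For anyone following along, here is a minimal, standalone sketch of the stride idea (this is not the branch code, and `broadcast_batch_shape` / `broadcast_strides` are hypothetical helpers): size-1 batch axes get a stride of 0, so every output batch index maps back to a valid input batch index.

```rust
// Compute the broadcast shape of the batch (non-matrix) dims:
// dims must be equal, or one of them must be 1.
fn broadcast_batch_shape(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    assert_eq!(lhs.len(), rhs.len(), "both tensors have the same rank in burn");
    lhs.iter()
        .zip(rhs.iter())
        .map(|(&l, &r)| match (l, r) {
            (l, r) if l == r => Some(l),
            (1, r) => Some(r),
            (l, 1) => Some(l),
            _ => None, // incompatible batch dims
        })
        .collect()
}

// Strides into a flattened batch index space; size-1 axes get stride 0,
// so the single matrix on that axis is reused for every output index.
fn broadcast_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; shape.len()];
    let mut acc = 1;
    for i in (0..shape.len()).rev() {
        strides[i] = if shape[i] == 1 { 0 } else { acc };
        acc *= shape[i];
    }
    strides
}

fn main() {
    let lhs_batch = [3, 1]; // batch dims of [3, 1, 2, 4]
    let rhs_batch = [1, 5]; // batch dims of [1, 5, 4, 6]
    let out_batch = broadcast_batch_shape(&lhs_batch, &rhs_batch).unwrap();
    assert_eq!(out_batch, vec![3, 5]);

    let lhs_strides = broadcast_strides(&lhs_batch); // [1, 0]
    let rhs_strides = broadcast_strides(&rhs_batch); // [0, 1]

    // For each output batch index (j, k), pair up the input matrices.
    for j in 0..out_batch[0] {
        for k in 0..out_batch[1] {
            let lhs_idx = j * lhs_strides[0] + k * lhs_strides[1];
            let rhs_idx = j * rhs_strides[0] + k * rhs_strides[1];
            println!("out[{j}][{k}] = matmul(lhs[{lhs_idx}], rhs[{rhs_idx}])");
        }
    }
}
```

The actual kernel would then run a plain (n x m) @ (m x p) matmul per output batch index; the broadcast only changes which input matrices get paired, not the inner product itself.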
Hi @lancelet, thanks for looking into this! 🙂 I'll take a quick look at your PR for this issue specifically.
We do want to improve our documentation regarding the general broadcasting semantics in burn, and also matmul. If you want to take a stab at it, you could open another PR!
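If it helps as a starting point for the docs, here is a small sketch of the batch-broadcasting rule matmul is expected to follow (mirroring NumPy/PyTorch semantics; `expected_matmul_shape` is a hypothetical helper for illustration, not part of burn's API): the last two dims are the matrix dims, and each remaining batch dim must either match or be 1 on one side.

```rust
// Hypothetical helper: the output shape matmul should produce under
// NumPy/PyTorch-style batch broadcasting.
fn expected_matmul_shape<const D: usize>(lhs: [usize; D], rhs: [usize; D]) -> [usize; D] {
    assert!(D >= 2, "matmul needs at least 2 dims");
    assert_eq!(lhs[D - 1], rhs[D - 2], "inner dims must match");

    let mut out = [0; D];
    // Batch dims broadcast elementwise: equal, or one side is 1.
    for i in 0..D - 2 {
        out[i] = match (lhs[i], rhs[i]) {
            (l, r) if l == r => l,
            (1, r) => r,
            (l, 1) => l,
            (l, r) => panic!("incompatible batch dims {l} and {r}"),
        };
    }
    // Matrix dims: (n x m) @ (m x p) -> (n x p).
    out[D - 2] = lhs[D - 2];
    out[D - 1] = rhs[D - 1];
    out
}

fn main() {
    // The example from this issue.
    assert_eq!(
        expected_matmul_shape([3, 1, 2, 4], [1, 5, 4, 6]),
        [3, 5, 2, 6]
    );
    println!("ok");
}
```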