[ndarray] General matmul broadcasting bug #1646

Closed · Fixed by #1679
laggui opened this issue Apr 16, 2024 · 3 comments
Labels: bug (Something isn't working)

@laggui (Member) commented Apr 16, 2024

Describe the bug
Was double-checking our broadcasting support following #1499.

The general matmul implementation for the ndarray backend doesn't compute the correct result for batched matrix broadcasting.

To Reproduce

use burn::{
    backend::{NdArray, Wgpu},
    tensor::Tensor,
};

fn main() {
    type B = NdArray;
    // type B = Wgpu;
    let device = Default::default();

    // (j x 1 x n x m) @ (1 x k x m x p) -> (j x k x n x p)
    let tensor_1 = Tensor::<B, 4>::ones([3, 1, 2, 4], &device);
    let tensor_2 = Tensor::<B, 4>::ones([1, 5, 4, 6], &device);
    let dims1 = tensor_1.dims();
    let dims2 = tensor_2.dims();
    let tensor_3 = tensor_1.matmul(tensor_2);
    println!("{:?} @ {:?} = {:?}", dims1, dims2, tensor_3.dims());
    // [3, 1, 2, 4] @ [1, 5, 4, 6] = [3, 5, 2, 6]
}

This code panics at tensor_1.matmul(tensor_2):

thread '<unnamed>' panicked at /home/laggui/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/dimension/mod.rs:361:5:
collapse_axis: Index 3 must be less than axis length 3 for array with shape [3, 2, 4]
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread '<unnamed>' panicked at /home/laggui/.cargo/registry/src/index.crates.io-6f17d22bba15001f/ndarray-0.15.6/src/dimension/mod.rs:361:5:
collapse_axis: Index 4 must be less than axis length 3 for array with shape [3, 2, 4]

Expected behavior
The matmul should produce an output with shape [3, 5, 2, 6], as it does on the Wgpu backend.
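The resulting shape follows the usual NumPy-style broadcasting rule over the leading (batch) dimensions: sizes are compatible when they are equal or when either is 1. As a minimal sketch of that rule (my own illustration, not burn's code; broadcast_batch_shape is a hypothetical helper), assuming both tensors have the same rank:

fn broadcast_batch_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| match (x, y) {
            _ if x == y => Some(x), // equal sizes pass through
            (1, _) => Some(y),      // size 1 broadcasts to the other size
            (_, 1) => Some(x),
            _ => None,              // incompatible batch dimensions
        })
        .collect()
}

fn main() {
    // Batch dims of [3, 1, 2, 4] and [1, 5, 4, 6] are [3, 1] and [1, 5].
    assert_eq!(broadcast_batch_shape(&[3, 1], &[1, 5]), Some(vec![3, 5]));
}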

laggui added the bug label on Apr 16, 2024
@lancelet (Contributor) commented Apr 22, 2024

EDIT: I have a WIP branch here: https://github.com/lancelet/burn/tree/matmul-broadcasting
In that branch, I use strides for the non-matrix dimensions (i.e. the batch, depth, channel, or whatever they represent) to handle arbitrary broadcasting.

It's getting the correct result for the above example, but it's failing on some autoregressive test cases with an error about an incompatible memory layout. The error occurs when I try to flatten / reshape the arrays in the same way the original code does. Still investigating.
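As a rough illustration of that stride idea (a sketch under my own assumptions, not the branch's actual code; input_batch_index is a hypothetical helper): giving a size-1 batch dimension a stride of 0 means every output position along it reads the same input matrix, which is the index mapping below.

fn input_batch_index(out_idx: &[usize], in_dims: &[usize]) -> Vec<usize> {
    out_idx
        .iter()
        .zip(in_dims.iter())
        // A size-1 input dimension always maps back to index 0.
        .map(|(&o, &d)| if d == 1 { 0 } else { o })
        .collect()
}

fn main() {
    // For output batch position [2, 4] in the broadcast [3, 5] grid:
    // lhs batch dims [3, 1] -> read its matrix at [2, 0];
    // rhs batch dims [1, 5] -> read its matrix at [0, 4].
    assert_eq!(input_batch_index(&[2, 4], &[3, 1]), vec![2, 0]);
    assert_eq!(input_batch_index(&[2, 4], &[1, 5]), vec![0, 4]);
}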

@lancelet (Contributor) commented:

OK; PR raised here: #1679

I'm completely new to burn, so please let me know what extra tests, examples, docs, etc. should be added.

I saw a mention in another PR of improving the docs for matmul. Should I put that into this PR as well?

@laggui (Member, Author) commented Apr 22, 2024

Hi @lancelet, thanks for looking into this! 🙂 I'll take a quick look at your PR for this issue specifically.

We do want to improve our documentation on the general broadcasting semantics in burn, matmul included. If you want to take a stab at it, you could open another PR!
