
Handle ndarray matmul broadcasting #1679

Merged
merged 3 commits into from
Apr 29, 2024

Conversation

lancelet
Contributor

@lancelet lancelet commented Apr 22, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Fixes #1646. Related to #1499.

Changes

  • Use strides to map linear batch indices from the output back to the input arrays.

In matmul, we have arrays of shape [l_0, l_1, ..., l_N, m, k] @ [r_0, r_1, ..., r_N, k, n]. The dimensions l_i and r_i may be broadcast using the following (fairly standard) rules:

  1. If l_i == r_i, then the output dimension is that common size (no broadcasting).
  2. If either l_i or r_i is 1, then the output dimension is max(l_i, r_i), and the array whose dimension equals 1 is broadcast along axis i.

(The innermost two dimensions of the output are always [m, n], following standard matrix multiplication rules.)
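The two broadcasting rules above can be sketched as a small standalone function. This is illustrative only; the function name and the use of `Option` for incompatible shapes are my own, not taken from the PR:

```rust
/// Compute the broadcast batch shape of two equal-rank batch shapes.
/// Equal dims pass through; a dim of 1 broadcasts to the other dim;
/// anything else is incompatible.
fn broadcast_batch_shape(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    if lhs.len() != rhs.len() {
        return None;
    }
    lhs.iter()
        .zip(rhs.iter())
        .map(|(&l, &r)| {
            if l == r || l == 1 || r == 1 {
                Some(l.max(r)) // rule 1 or rule 2
            } else {
                None // neither equal nor 1: incompatible
            }
        })
        .collect()
}

fn main() {
    // Batch dims of [2, 1, 2, 2] @ [1, 2, 2, 2] broadcast to [2, 2],
    // matching the test case added in this PR.
    assert_eq!(broadcast_batch_shape(&[2, 1], &[1, 2]), Some(vec![2, 2]));
    assert_eq!(broadcast_batch_shape(&[2, 3], &[3, 2]), None);
    println!("ok");
}
```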

When performing the stacked matrix multiplies within the batched matmul, it is necessary to be able to look up which matrices should be multiplied. In this PR, that is done using a stride approach:

  1. We iterate over a flattened index into the output array's batches (same as before).
  2. We convert this flattened batch index into a component batch index.
  3. We multiply the component batch index by batch strides for the left and right arrays to produce a flattened index into both.

This uses a standard stride lookup technique for broadcasting, where the stride for a broadcast dimension is zero.
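A minimal sketch of steps 2 and 3 and the zero-stride trick, assuming row-major layout. The helper names are hypothetical and not the PR's actual code:

```rust
/// Row-major strides for `shape`, with stride 0 on any axis that is
/// broadcast (size 1 where the output size differs), so indexing along
/// that axis "sticks" at position 0.
fn broadcast_strides(shape: &[usize], out_shape: &[usize]) -> Vec<usize> {
    // Ordinary row-major strides first.
    let mut strides = vec![0; shape.len()];
    let mut acc = 1;
    for i in (0..shape.len()).rev() {
        strides[i] = acc;
        acc *= shape[i];
    }
    // Zero out strides on broadcast axes.
    for i in 0..shape.len() {
        if shape[i] == 1 && out_shape[i] != 1 {
            strides[i] = 0;
        }
    }
    strides
}

/// Map a flat batch index in the output to a flat batch index in an input:
/// decompose the flat index into a multi-index over `out_shape`, then dot
/// it with the input's broadcast-aware strides.
fn map_batch_index(flat_out: usize, out_shape: &[usize], in_strides: &[usize]) -> usize {
    let mut rem = flat_out;
    let mut flat_in = 0;
    for i in (0..out_shape.len()).rev() {
        let idx = rem % out_shape[i];
        rem /= out_shape[i];
        flat_in += idx * in_strides[i];
    }
    flat_in
}

fn main() {
    // lhs batch shape [2, 1] against output batches [2, 2]:
    // output batch (i, j) should read lhs batch (i, 0).
    let s = broadcast_strides(&[2, 1], &[2, 2]);
    assert_eq!(s, vec![1, 0]);
    // Output flat indices 0..4 = (0,0),(0,1),(1,0),(1,1) map to lhs 0,0,1,1.
    let mapped: Vec<usize> = (0..4).map(|f| map_batch_index(f, &[2, 2], &s)).collect();
    assert_eq!(mapped, vec![0, 0, 1, 1]);
    println!("ok");
}
```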

Testing

  • Added test for a [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2] case which was previously returning incorrect dimensions.
  • Added tests for broadcast shape and stride calculations.

- Use strides to map linear batch indices from
  the output back to the input arrays.
- let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
- let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
+ let lhs_array = NdArray::<E>::float_reshape(lhs, Shape::new([num_l_batches, m, k])).array;
+ let rhs_array = NdArray::<E>::float_reshape(rhs, Shape::new([num_r_batches, k, n])).array;
Contributor Author
I originally tried this with just into_shape, but that fails in some cases with an error about an incompatible layout. The original code was also calling float_reshape (below), so I've re-used that here.

#[derive(Debug, PartialEq)]
struct Strides {
    strides: Vec<usize>,
}
Contributor Author

The Vec here is a little frustrating. I gather that with nightly, it is possible to write [usize; {D - 2}] or similar. However, I've taken care to allocate these with a fixed capacity, so hopefully they're not too much slower than a sized array.
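The fixed-capacity point can be illustrated with `Vec::with_capacity`: reserving the full length up front means the subsequent pushes never reallocate, so the main cost over a sized array is one heap allocation. A sketch with a hypothetical helper, not the PR's code:

```rust
/// Build a strides vector of length d - 2 (the batch axes of a rank-d
/// tensor), allocating its backing storage exactly once.
fn build_strides(d: usize) -> Vec<usize> {
    // Reserve full capacity up front; pushes below will not reallocate.
    let mut strides = Vec::with_capacity(d - 2);
    for i in 0..d - 2 {
        strides.push(i); // placeholder values for illustration
    }
    strides
}

fn main() {
    let s = build_strides(4);
    assert_eq!(s.len(), 2);
    assert!(s.capacity() >= 2);
    println!("{:?}", s);
}
```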

Member

@laggui laggui left a comment


Thanks for your contribution! 🎉

At first glance the implementation looks good to me. Will request @nathanielsimard since he implemented the previous version.

@laggui laggui requested a review from nathanielsimard April 22, 2024 12:51

codecov bot commented Apr 22, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.54%. Comparing base (1cdceb5) to head (1b5b991).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1679      +/-   ##
==========================================
+ Coverage   86.51%   86.54%   +0.02%     
==========================================
  Files         696      696              
  Lines       81498    81653     +155     
==========================================
+ Hits        70506    70664     +158     
+ Misses      10992    10989       -3     


@antimora
Collaborator

It looks like we have two approvals. @nathanielsimard, when you have a chance, could you also review it?

@antimora antimora added the bug Something isn't working label Apr 26, 2024
Member

@nathanielsimard nathanielsimard left a comment


@antimora antimora requested a review from louisfd April 26, 2024 18:31
Member

@louisfd louisfd left a comment


Thanks for fixing this bug!
I spotted two typos, but overall it looks good to me.

@antimora antimora merged commit ab50143 into tracel-ai:main Apr 29, 2024
14 checks passed
nathanielsimard pushed a commit that referenced this pull request May 3, 2024
* Handle ndarray matmul broadcasting

- Use strides to map linear batch indices from
  the output back to the input arrays.

* Fix typos
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[ndarray] General matmul broadcasting bug
6 participants