Handle ndarray matmul broadcasting #1679
Conversation
- Use strides to map linear batch indices from the output back to the input arrays.
```diff
- let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
- let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
+ let lhs_array = NdArray::<E>::float_reshape(lhs, Shape::new([num_l_batches, m, k])).array;
+ let rhs_array = NdArray::<E>::float_reshape(rhs, Shape::new([num_r_batches, k, n])).array;
```
I originally tried this with just `into_shape`, but that fails in some cases with an error about an incompatible layout. The original code was also calling `float_reshape` (below), so I've re-used that here.
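For context, here is a minimal sketch (my illustration, not code from the PR) of the kind of layout that makes `ndarray`'s `into_shape` fail: it only reinterprets memory for C- or F-contiguous arrays, so arrays with permuted axes need a copying reshape like `float_reshape` instead.

```rust
use ndarray::Array3;

fn main() {
    // A C-contiguous [2, 3, 4] array.
    let a = Array3::<f32>::zeros((2, 3, 4));

    // Permuting the axes swaps strides without moving data, leaving an
    // array that is neither C- nor F-contiguous.
    let p = a.permuted_axes([1, 0, 2]); // shape [3, 2, 4]

    // `into_shape` cannot reinterpret this layout in place, so it returns
    // an incompatible-layout ShapeError.
    assert!(p.into_shape((6, 4)).is_err());
}
```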
```rust
#[derive(Debug, PartialEq)]
struct Strides {
    strides: Vec<usize>,
}
```
The `Vec` here is a little frustrating. I gather that with nightly, it is possible to write `[usize; {D - 2}]` or similar. However, I've taken care to allocate these with a fixed capacity, so hopefully they're not too much slower than a sized array.
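As a rough illustration of that point (hypothetical code, not the PR's), `Vec::with_capacity` makes the single allocation up front, so pushing the `D - 2` batch strides never reallocates:

```rust
// Hypothetical sketch: compute row-major strides over the D - 2 batch
// dimensions of a rank-D shape (assumes D >= 2, as matmul requires),
// with one up-front allocation.
fn batch_strides<const D: usize>(shape: &[usize; D]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(D - 2);
    let mut stride = 1;
    // Walk the batch dimensions from innermost to outermost, accumulating
    // the product of the dimension sizes seen so far.
    for dim in shape[..D - 2].iter().rev() {
        strides.push(stride);
        stride *= dim;
    }
    strides.reverse();
    strides
}
```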
Thanks for your contribution! 🎉
At first glance the implementation looks good to me. I'll request a review from @nathanielsimard since he implemented the previous version.
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1679      +/-   ##
==========================================
+ Coverage   86.51%   86.54%   +0.02%
==========================================
  Files         696      696
  Lines       81498    81653     +155
==========================================
+ Hits        70506    70664     +158
+ Misses      10992    10989       -3
```

☔ View full report in Codecov by Sentry.
It looks like we have two approvals. @nathanielsimard, when you have a chance, could you also review it?
No performance loss was detected for the matmul with somewhat large shapes: https://burn.dev/benchmarks/community-benchmarks/?version1=21a2c6553c7d4df0b3de830112888b39fab6a9d0&versionLabel1=Version+1&version2=-&versionLabel2=Version+2&backend=ndarray&device=All&name=All&os=Pop%21_OS+22.4.0+%28jammy%29+%5B64-bit%5D&sysHardware=Any&user=SharpPrecision&search=true
@louisfd, I would like your review on this before merging.
Thanks for fixing this bug!
I spotted two typos, but overall it looks good to me.
* Handle ndarray matmul broadcasting
  - Use strides to map linear batch indices from the output back to the input arrays.
* Fix typos
Pull Request Template

Checklist

- Confirmed that the `run-checks all` script has been executed.

Related Issues/PRs

Fixes #1646. Related to #1499.
Changes

In `matmul`, we have arrays of shape `[l_0, l_1, ..., l_N, m, k] @ [r_0, r_1, ..., r_N, k, n]`. The dimensions `l_i` and `r_i` may be broadcast using the following (fairly standard) rules:

- If `l_i == r_i`, then the output dimension is that common size (no broadcasting).
- If either `l_i` or `r_i` is 1, then the output dimension is `max(l_i, r_i)` and the array with the dimension equal to 1 is broadcast along axis `i`.

(The innermost two dimensions of the output are always `[m, n]`, following standard matrix multiplication rules.)

When performing the stacked matrix multiplies within the batched `matmul`, it is necessary to look up which matrices should be multiplied. In this PR, that is done with a standard stride lookup technique for broadcasting, where the stride for a broadcast dimension is zero.
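A minimal sketch of that technique (my illustration, not the PR's exact code): each input's batch strides are computed against its own shape, with broadcast dimensions given stride zero, so a linear batch index in the output can be mapped back to a linear batch index in each input.

```rust
// Sketch of the stride lookup: map a linear batch index in the output
// back to a linear batch index in one input. `in_strides` are the input's
// linear strides over its batch dims, with 0 for broadcast dimensions.
fn input_batch_index(out_index: usize, out_shape: &[usize], in_strides: &[usize]) -> usize {
    let mut remaining = out_index;
    let mut in_index = 0;
    // Decompose the linear index into per-dimension coordinates, starting
    // from the innermost batch dimension.
    for (dim, stride) in out_shape.iter().zip(in_strides).rev() {
        in_index += (remaining % dim) * stride;
        remaining /= dim;
    }
    in_index
}

fn main() {
    // Batch shapes [2, 1] (lhs) and [1, 2] (rhs) broadcast to [2, 2].
    let out_shape = [2, 2];
    let lhs_strides = [1, 0]; // axis 1 is broadcast in lhs -> stride 0
    let rhs_strides = [0, 1]; // axis 0 is broadcast in rhs -> stride 0
    // Output batch 3 has coordinates [1, 1] -> lhs batch 1, rhs batch 1.
    assert_eq!(input_batch_index(3, &out_shape, &lhs_strides), 1);
    assert_eq!(input_batch_index(3, &out_shape, &rhs_strides), 1);
}
```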
Testing

Added a test for the `[2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2]` case, which was previously returning incorrect dimensions.
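For reference, a small sketch (my own, not the PR's test) of the broadcast rule applied to the tested shapes:

```rust
// Sketch: broadcast two batch shapes element-wise per the rules above.
fn broadcast_batch(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    lhs.iter()
        .zip(rhs)
        .map(|(&l, &r)| match (l, r) {
            _ if l == r => Some(l), // equal sizes: no broadcasting
            (1, _) => Some(r),      // lhs broadcast along this axis
            (_, 1) => Some(l),      // rhs broadcast along this axis
            _ => None,              // incompatible batch dimensions
        })
        .collect()
}

fn main() {
    // [2, 1, 2, 2] @ [1, 2, 2, 2]: batch dims [2, 1] and [1, 2] broadcast
    // to [2, 2]; appending the matrix dims [m, n] = [2, 2] gives the
    // expected output shape [2, 2, 2, 2].
    assert_eq!(broadcast_batch(&[2, 1], &[1, 2]), Some(vec![2, 2]));
}
```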