implement batched serial getrs #2483
It looks good, but I would rewrite the analytical test to only use getrs.
/// \param k [in] Number of superdiagonals or subdiagonals of matrix A
/// \param BlkSize [in] Block size of matrix A
template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
void impl_test_batched_getrs_analytical(const int N) {
In my opinion it would be better not to call getrf here, but instead to write the LU output directly into A and ipiv and test only getrs. The idea is to isolate where an error could be coming from: that way this unit test fails only if there is an issue in getrs, not in getrf. For instance, you could set A, lu, ipiv and b as follows:
A = [[1,  1],
     [1, -1]]
ipiv = [0, 1]
lu = [[1,  1],
      [1, -2]]
b = [[2],
     [0]]
Then call getrs directly and check that the output is x = [[1], [1]].
Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
Thanks for the update
This PR implements the getrs function.
The following files are added:
- KokkosBatched_Getrs_Serial_Impl.hpp: Internal interfaces
- KokkosBatched_Getrs_Serial_Internal.hpp: Implementation details
- KokkosBatched_Getrs.hpp: APIs
- Test_Batched_SerialGetrs.hpp: Unit tests for getrs
Detailed description
It solves a system of linear equations A * x = b with a general N-by-N matrix A, using the LU factorization computed by getrf. The arguments have the following shapes:
- A: (batch_count, n, n). The N-by-N matrix factorized by getrf, where A = P * L * U; the unit diagonal elements of L are not stored.
- IPIV: (batch_count, n). The pivot indices from getrf: for 0 <= i < n, row i of the matrix was interchanged with row IPIV(i).
Parallelization is done over the batch direction. This is efficient only when A is given in LayoutLeft for GPUs and LayoutRight for CPUs.
Tests
- Build a matrix A and factorize it into LU with ipiv by getrf. Then solve A * x = b with getrs to get x, while keeping the original b in x_ref. Finally, confirm that A * x is equal to b (= x_ref) using gem.
- Choose a fixed A to confirm LU == A.
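The round-trip check from the first test (factorize, solve, then compare A * x against the saved right-hand side) can be sketched in plain C++ as follows. This is an illustration under stated assumptions, not the code in Test_Batched_SerialGetrs.hpp: `lu_factorize`, `lu_solve` and `max_round_trip_residual` are hypothetical stand-ins for getrf, getrs and the final comparison, and the batch is a simple serial loop rather than a Kokkos parallel dispatch.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // one n-by-n matrix, row-major
using Vector = std::vector<double>;

// Hypothetical stand-in for getrf: in-place LU with partial pivoting, 0-based pivots.
void lu_factorize(Matrix& a, std::vector<int>& ipiv) {
  const std::size_t n = a.size();
  ipiv.assign(n, 0);
  for (std::size_t k = 0; k < n; ++k) {
    std::size_t p = k;  // row holding the largest magnitude in column k
    for (std::size_t i = k + 1; i < n; ++i)
      if (std::fabs(a[i][k]) > std::fabs(a[p][k])) p = i;
    ipiv[k] = static_cast<int>(p);
    if (p != k) std::swap(a[k], a[p]);
    for (std::size_t i = k + 1; i < n; ++i) {
      a[i][k] /= a[k][k];  // multiplier, stored in the strict lower triangle
      for (std::size_t j = k + 1; j < n; ++j) a[i][j] -= a[i][k] * a[k][j];
    }
  }
}

// Hypothetical stand-in for getrs: overwrites b with the solution x.
void lu_solve(const Matrix& lu, const std::vector<int>& ipiv, Vector& b) {
  const std::size_t n = lu.size();
  for (std::size_t i = 0; i < n; ++i) std::swap(b[i], b[ipiv[i]]);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < i; ++j) b[i] -= lu[i][j] * b[j];
  for (std::size_t i = n; i-- > 0;) {
    for (std::size_t j = i + 1; j < n; ++j) b[i] -= lu[i][j] * b[j];
    b[i] /= lu[i][i];
  }
}

// Largest |A * x - x_ref| entry over the whole batch.
double max_round_trip_residual(const std::vector<Matrix>& batch_a,
                               const std::vector<Vector>& batch_b) {
  double worst = 0.0;
  for (std::size_t ib = 0; ib < batch_a.size(); ++ib) {
    Matrix lu = batch_a[ib];  // keep the original A for the final product
    std::vector<int> ipiv;
    lu_factorize(lu, ipiv);
    const Vector x_ref = batch_b[ib];  // original right-hand side
    Vector x = x_ref;
    lu_solve(lu, ipiv, x);
    for (std::size_t i = 0; i < x.size(); ++i) {
      double ax = 0.0;  // i-th entry of A * x
      for (std::size_t j = 0; j < x.size(); ++j) ax += batch_a[ib][i][j] * x[j];
      worst = std::max(worst, std::fabs(ax - x_ref[i]));
    }
  }
  return worst;
}
```

Each batch entry is independent, which is why the loop over `ib` is the natural place to parallelize. Note that a failure of this round-trip check cannot distinguish a factorization bug from a solve bug, which is the motivation for the separate analytical test.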