Skip to content

Commit 249125b

Browse files
amontoison and maleadt authored
Update the support library for oneAPI 2024 (#389)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 70bce16 commit 249125b

10 files changed

+1328
-461
lines changed

.gitignore

+1
Original file line number | Diff line number | Diff line change
@@ -1 +1,2 @@
11
LocalPreferences.toml
2+
Manifest.toml

deps/Project.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -12,4 +12,4 @@ oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
1212
oneAPI_Support_Headers_jll = "24f86df5-245d-5634-a4cc-32433d9800b3"
1313

1414
[compat]
15-
oneAPI_Support_Headers_jll = "=2023.0.0"
15+
oneAPI_Support_Headers_jll = "=2024.0.0"

deps/generate_interfaces.jl

+10-8
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,7 @@ function generate_headers(library::String, filename::String, output::String)
7171

7272
if library == "blas"
7373
header = replace(header, "compute_mode mode = MKL_BLAS_COMPUTE_MODE" => "")
74+
header = replace(header, "index_base base=index_base::zero" => "onemklIndex base")
7475

7576
header = replace(header, "sycl::buffer<Ta> &" => "Ta *")
7677
header = replace(header, "sycl::buffer<Tb> &" => "Tb *")
@@ -344,6 +345,7 @@ function generate_cpp(library::String, filename::String, output::String)
344345
for type in ("onemklTranspose", "onemklSide", "onemklUplo", "onemklDiag", "onemklGenerate",
345346
"onemklJob", "onemklJobsvd", "onemklCompz", "onemklRangev", "onemklIndex", "onemklProperty")
346347
parameters = replace(parameters, Regex("$type ([a-z_]+),") => SubstitutionString("convert(\\1),"))
348+
parameters = replace(parameters, Regex(", $type ([a-z_]+)") => SubstitutionString(", convert(\\1)"))
347349
end
348350
parameters = replace(parameters, r" >([a-z]+)" => s" >(\1)")
349351
parameters = replace(parameters, r" \*>([a-z]+)" => s"*>(\1)")
@@ -375,7 +377,7 @@ end
375377

376378
generate_headers("lapack", lapack, "onemkl_lapack.h")
377379
generate_headers("blas", blas, "onemkl_blas.h")
378-
generate_headers("sparse", sparse, "onemkl_sparse.h")
380+
# generate_headers("sparse", sparse, "onemkl_sparse.h")
379381

380382
io = open("src/onemkl.h", "w")
381383
headers_prologue = read("onemkl_prologue.h", String)
@@ -386,16 +388,16 @@ write(io, headers_blas)
386388
headers_lapack = read("onemkl_lapack.h", String)
387389
write(io, "// LAPACK\n")
388390
write(io, headers_lapack)
389-
headers_sparse = read("onemkl_sparse.h", String)
390-
write(io, "// SPARSE\n")
391-
write(io, headers_sparse)
391+
# headers_sparse = read("onemkl_sparse.h", String)
392+
# write(io, "// SPARSE\n")
393+
# write(io, headers_sparse)
392394
headers_epilogue = read("onemkl_epilogue.h", String)
393395
write(io, headers_epilogue)
394396
close(io)
395397

396398
generate_cpp("lapack", lapack, "onemkl_lapack.cpp")
397399
generate_cpp("blas", blas, "onemkl_blas.cpp")
398-
generate_cpp("sparse", sparse, "onemkl_sparse.cpp")
400+
# generate_cpp("sparse", sparse, "onemkl_sparse.cpp")
399401

400402
io = open("src/onemkl.cpp", "w")
401403
cpp_prologue = read("onemkl_prologue.cpp", String)
@@ -406,9 +408,9 @@ write(io, cpp_blas)
406408
cpp_lapack = read("onemkl_lapack.cpp", String)
407409
write(io, "// LAPACK\n")
408410
write(io, cpp_lapack)
409-
cpp_sparse = read("onemkl_sparse.cpp", String)
410-
write(io, "// SPARSE\n")
411-
write(io, cpp_sparse)
411+
# cpp_sparse = read("onemkl_sparse.cpp", String)
412+
# write(io, "// SPARSE\n")
413+
# write(io, cpp_sparse)
412414
cpp_epilogue = read("onemkl_epilogue.cpp", String)
413415
write(io, cpp_epilogue)
414416
close(io)

deps/src/onemkl.cpp

+60-16
Original file line number | Diff line number | Diff line change
@@ -1245,50 +1245,50 @@ extern "C" int onemklZdotu(syclQueue_t device_queue, int64_t n, double _Complex
12451245
return 0;
12461246
}
12471247

1248-
extern "C" int onemklSiamax(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result) {
1249-
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
1248+
extern "C" int onemklSiamax(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result, onemklIndex base) {
1249+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result, convert(base));
12501250
__FORCE_MKL_FLUSH__(status);
12511251
return 0;
12521252
}
12531253

1254-
extern "C" int onemklDiamax(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result) {
1255-
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
1254+
extern "C" int onemklDiamax(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result, onemklIndex base) {
1255+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result, convert(base));
12561256
__FORCE_MKL_FLUSH__(status);
12571257
return 0;
12581258
}
12591259

1260-
extern "C" int onemklCiamax(syclQueue_t device_queue, int64_t n, float _Complex *x, int64_t incx, int64_t *result) {
1261-
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, reinterpret_cast<std::complex<float>*>(x), incx, result);
1260+
extern "C" int onemklCiamax(syclQueue_t device_queue, int64_t n, float _Complex *x, int64_t incx, int64_t *result, onemklIndex base) {
1261+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, reinterpret_cast<std::complex<float>*>(x), incx, result, convert(base));
12621262
__FORCE_MKL_FLUSH__(status);
12631263
return 0;
12641264
}
12651265

1266-
extern "C" int onemklZiamax(syclQueue_t device_queue, int64_t n, double _Complex *x, int64_t incx, int64_t *result) {
1267-
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, reinterpret_cast<std::complex<double>*>(x), incx, result);
1266+
extern "C" int onemklZiamax(syclQueue_t device_queue, int64_t n, double _Complex *x, int64_t incx, int64_t *result, onemklIndex base) {
1267+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, reinterpret_cast<std::complex<double>*>(x), incx, result, convert(base));
12681268
__FORCE_MKL_FLUSH__(status);
12691269
return 0;
12701270
}
12711271

1272-
extern "C" int onemklSiamin(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result) {
1273-
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
1272+
extern "C" int onemklSiamin(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result, onemklIndex base) {
1273+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result, convert(base));
12741274
__FORCE_MKL_FLUSH__(status);
12751275
return 0;
12761276
}
12771277

1278-
extern "C" int onemklDiamin(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result) {
1279-
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
1278+
extern "C" int onemklDiamin(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result, onemklIndex base) {
1279+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result, convert(base));
12801280
__FORCE_MKL_FLUSH__(status);
12811281
return 0;
12821282
}
12831283

1284-
extern "C" int onemklCiamin(syclQueue_t device_queue, int64_t n, float _Complex *x, int64_t incx, int64_t *result) {
1285-
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, reinterpret_cast<std::complex<float>*>(x), incx, result);
1284+
extern "C" int onemklCiamin(syclQueue_t device_queue, int64_t n, float _Complex *x, int64_t incx, int64_t *result, onemklIndex base) {
1285+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, reinterpret_cast<std::complex<float>*>(x), incx, result, convert(base));
12861286
__FORCE_MKL_FLUSH__(status);
12871287
return 0;
12881288
}
12891289

1290-
extern "C" int onemklZiamin(syclQueue_t device_queue, int64_t n, double _Complex *x, int64_t incx, int64_t *result) {
1291-
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, reinterpret_cast<std::complex<double>*>(x), incx, result);
1290+
extern "C" int onemklZiamin(syclQueue_t device_queue, int64_t n, double _Complex *x, int64_t incx, int64_t *result, onemklIndex base) {
1291+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, reinterpret_cast<std::complex<double>*>(x), incx, result, convert(base));
12921292
__FORCE_MKL_FLUSH__(status);
12931293
return 0;
12941294
}
@@ -2244,6 +2244,50 @@ extern "C" int onemklZgetrf_batch(syclQueue_t device_queue, int64_t m, int64_t n
22442244
return 0;
22452245
}
22462246

2247+
extern "C" int64_t onemklSgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t lda) {
2248+
int64_t scratchpad_size = oneapi::mkl::lapack::getrfnp_scratchpad_size<float>(device_queue->val, m, n, lda);
2249+
return scratchpad_size;
2250+
}
2251+
2252+
extern "C" int64_t onemklDgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t lda) {
2253+
int64_t scratchpad_size = oneapi::mkl::lapack::getrfnp_scratchpad_size<double>(device_queue->val, m, n, lda);
2254+
return scratchpad_size;
2255+
}
2256+
2257+
extern "C" int64_t onemklCgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t lda) {
2258+
int64_t scratchpad_size = oneapi::mkl::lapack::getrfnp_scratchpad_size<std::complex<float>>(device_queue->val, m, n, lda);
2259+
return scratchpad_size;
2260+
}
2261+
2262+
extern "C" int64_t onemklZgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t lda) {
2263+
int64_t scratchpad_size = oneapi::mkl::lapack::getrfnp_scratchpad_size<std::complex<double>>(device_queue->val, m, n, lda);
2264+
return scratchpad_size;
2265+
}
2266+
2267+
extern "C" int onemklCgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda, float _Complex *scratchpad, int64_t scratchpad_size) {
2268+
auto status = oneapi::mkl::lapack::getrfnp(device_queue->val, m, n, reinterpret_cast<std::complex<float>*>(a), lda, reinterpret_cast<std::complex<float>*>(scratchpad), scratchpad_size);
2269+
__FORCE_MKL_FLUSH__(status);
2270+
return 0;
2271+
}
2272+
2273+
extern "C" int onemklDgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double *scratchpad, int64_t scratchpad_size) {
2274+
auto status = oneapi::mkl::lapack::getrfnp(device_queue->val, m, n, a, lda, scratchpad, scratchpad_size);
2275+
__FORCE_MKL_FLUSH__(status);
2276+
return 0;
2277+
}
2278+
2279+
extern "C" int onemklSgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float *scratchpad, int64_t scratchpad_size) {
2280+
auto status = oneapi::mkl::lapack::getrfnp(device_queue->val, m, n, a, lda, scratchpad, scratchpad_size);
2281+
__FORCE_MKL_FLUSH__(status);
2282+
return 0;
2283+
}
2284+
2285+
extern "C" int onemklZgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda, double _Complex *scratchpad, int64_t scratchpad_size) {
2286+
auto status = oneapi::mkl::lapack::getrfnp(device_queue->val, m, n, reinterpret_cast<std::complex<double>*>(a), lda, reinterpret_cast<std::complex<double>*>(scratchpad), scratchpad_size);
2287+
__FORCE_MKL_FLUSH__(status);
2288+
return 0;
2289+
}
2290+
22472291
extern "C" int64_t onemklSgetrfnp_batch_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t lda, int64_t stride_a, int64_t batch_size) {
22482292
int64_t scratchpad_size = oneapi::mkl::lapack::getrfnp_batch_scratchpad_size<float>(device_queue->val, m, n, lda, stride_a, batch_size);
22492293
return scratchpad_size;

deps/src/onemkl.h

+36-8
Original file line number | Diff line number | Diff line change
@@ -620,25 +620,29 @@ int onemklCdotu(syclQueue_t device_queue, int64_t n, float _Complex *x, int64_t
620620
int onemklZdotu(syclQueue_t device_queue, int64_t n, double _Complex *x, int64_t incx, double
621621
_Complex *y, int64_t incy, double _Complex *result);
622622

623-
int onemklSiamax(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result);
623+
int onemklSiamax(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result,
624+
onemklIndex base);
624625

625-
int onemklDiamax(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result);
626+
int onemklDiamax(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result,
627+
onemklIndex base);
626628

627629
int onemklCiamax(syclQueue_t device_queue, int64_t n, float _Complex *x, int64_t incx, int64_t
628-
*result);
630+
*result, onemklIndex base);
629631

630632
int onemklZiamax(syclQueue_t device_queue, int64_t n, double _Complex *x, int64_t incx, int64_t
631-
*result);
633+
*result, onemklIndex base);
632634

633-
int onemklSiamin(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result);
635+
int onemklSiamin(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, int64_t *result,
636+
onemklIndex base);
634637

635-
int onemklDiamin(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result);
638+
int onemklDiamin(syclQueue_t device_queue, int64_t n, double *x, int64_t incx, int64_t *result,
639+
onemklIndex base);
636640

637641
int onemklCiamin(syclQueue_t device_queue, int64_t n, float _Complex *x, int64_t incx, int64_t
638-
*result);
642+
*result, onemklIndex base);
639643

640644
int onemklZiamin(syclQueue_t device_queue, int64_t n, double _Complex *x, int64_t incx, int64_t
641-
*result);
645+
*result, onemklIndex base);
642646

643647
int onemklSasum(syclQueue_t device_queue, int64_t n, float *x, int64_t incx, float *result);
644648

@@ -1192,6 +1196,30 @@ int onemklZgetrf_batch(syclQueue_t device_queue, int64_t m, int64_t n, double _C
11921196
lda, int64_t stride_a, int64_t *ipiv, int64_t stride_ipiv, int64_t
11931197
batch_size, double _Complex *scratchpad, int64_t scratchpad_size);
11941198

1199+
int64_t onemklSgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t
1200+
lda);
1201+
1202+
int64_t onemklDgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t
1203+
lda);
1204+
1205+
int64_t onemklCgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t
1206+
lda);
1207+
1208+
int64_t onemklZgetrfnp_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n, int64_t
1209+
lda);
1210+
1211+
int onemklCgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, float _Complex *a, int64_t lda,
1212+
float _Complex *scratchpad, int64_t scratchpad_size);
1213+
1214+
int onemklDgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, double *a, int64_t lda, double
1215+
*scratchpad, int64_t scratchpad_size);
1216+
1217+
int onemklSgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, float *a, int64_t lda, float
1218+
*scratchpad, int64_t scratchpad_size);
1219+
1220+
int onemklZgetrfnp(syclQueue_t device_queue, int64_t m, int64_t n, double _Complex *a, int64_t lda,
1221+
double _Complex *scratchpad, int64_t scratchpad_size);
1222+
11951223
int64_t onemklSgetrfnp_batch_scratchpad_size(syclQueue_t device_queue, int64_t m, int64_t n,
11961224
int64_t lda, int64_t stride_a, int64_t batch_size);
11971225

0 commit comments

Comments
 (0)