|
| 1 | +/*! |
| 2 | + * Copyright (c) 2017 by Contributors |
| 3 | + * \file c_lapack_api.h |
| 4 | + * \brief Unified interface for LAPACK calls from within mxnet. |
| 5 | + * Purpose is to hide the platform specific differences. |
| 6 | + */ |
| 7 | +#ifndef MXNET_C_LAPACK_API_H_ |
| 8 | +#define MXNET_C_LAPACK_API_H_ |
| 9 | + |
| 10 | +// Manually maintained list of LAPACK interfaces that can be used |
| 11 | +// within MXNET. Conventions: |
| 12 | +// - Interfaces must be compliant with lapacke.h in terms of signature and |
| 13 | +// naming conventions so wrapping a function "foo" which has the |
| 14 | +// signature |
| 15 | +// lapack_int LAPACKE_foo(int, char, lapack_int, float* , lapack_int) |
| 16 | +// within lapacke.h should result in a wrapper with the following signature |
| 17 | +// int MXNET_LAPACK_foo(int, char, int, float* , int) |
| 18 | +// Note that function signatures in lapacke.h will always have as first |
| 19 | +// argument the storage order (row/col-major). All wrappers have to support |
| 20 | +// that argument. The underlying fortran functions will always assume a |
| 21 | +// column-major layout. It is the responsibility of the wrapper function |
| 22 | +// to handle the (usual) case that it is called with data in row-major |
| 23 | +// format, either by doing appropriate transpositions explicitly or using |
| 24 | +// transposition options of the underlying fortran function. |
| 25 | +// - It is ok to assume that matrices are stored in contiguous memory |
| 26 | +// (which removes the need to do special handling for lda/ldb parameters |
| 27 | +// and enables us to save additional matrix transpositions around |
| 28 | +// the fortran calls). |
| 29 | +// - It is desired to add some basic checking in the C++-wrappers in order |
| 30 | +// to catch simple mistakes when calling these wrappers. |
| 31 | +// - Must support compilation without lapack-package but issue runtime error in this case. |
| 32 | + |
| 33 | +#include <dmlc/logging.h> |
| 34 | + |
| 35 | +extern "C" { |
| 36 | + // Fortran signatures |
| 37 | + #define MXNET_LAPACK_FSIGNATURE1(func, dtype) \ |
| 38 | + void func##_(char* uplo, int* n, dtype* a, int* lda, int *info); |
| 39 | + |
| 40 | + MXNET_LAPACK_FSIGNATURE1(spotrf, float) |
| 41 | + MXNET_LAPACK_FSIGNATURE1(dpotrf, double) |
| 42 | + MXNET_LAPACK_FSIGNATURE1(spotri, float) |
| 43 | + MXNET_LAPACK_FSIGNATURE1(dpotri, double) |
| 44 | +} |
| 45 | + |
| 46 | +#define MXNET_LAPACK_ROW_MAJOR 101 |
| 47 | +#define MXNET_LAPACK_COL_MAJOR 102 |
| 48 | + |
| 49 | +#define CHECK_LAPACK_CONTIGUOUS(a, b) \ |
| 50 | + CHECK_EQ(a, b) << "non contiguous memory for array in lapack call"; |
| 51 | + |
| 52 | +#define CHECK_LAPACK_UPLO(a) \ |
| 53 | + CHECK(a == 'U' || a == 'L') << "neither L nor U specified as triangle in lapack call"; |
| 54 | + |
| 55 | +inline char loup(char uplo, bool invert) { return invert ? (uplo == 'U' ? 'L' : 'U') : uplo; } |
| 56 | + |
| 57 | +#if MXNET_USE_LAPACK |
| 58 | + |
| 59 | + #define MXNET_LAPACK_CWRAPPER1(func, dtype) \ |
| 60 | + inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda ) { \ |
| 61 | + CHECK_LAPACK_CONTIGUOUS(n, lda); \ |
| 62 | + CHECK_LAPACK_UPLO(uplo); \ |
| 63 | + char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \ |
| 64 | + int ret(0); \ |
| 65 | + func##_(&o, &n, a, &lda, &ret); \ |
| 66 | + return ret; \ |
| 67 | + } |
| 68 | + MXNET_LAPACK_CWRAPPER1(spotrf, float) |
| 69 | + MXNET_LAPACK_CWRAPPER1(dpotrf, double) |
| 70 | + MXNET_LAPACK_CWRAPPER1(spotri, float) |
| 71 | + MXNET_LAPACK_CWRAPPER1(dpotri, double) |
| 72 | + |
| 73 | +#else |
| 74 | + // use pragma message instead of warning |
| 75 | + #pragma message("Warning: lapack usage not enabled, linalg-operators will be not available." \ |
| 76 | + " Build with USE_LAPACK=1 to get lapack functionalities.") |
| 77 | + |
| 78 | + // Define compilable stubs. |
| 79 | + #define MXNET_LAPACK_CWRAPPER1(func, dtype) \ |
| 80 | + inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda ) { \ |
| 81 | + LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ |
| 82 | + return 1; \ |
| 83 | + } |
| 84 | + MXNET_LAPACK_CWRAPPER1(spotrf, float) |
| 85 | + MXNET_LAPACK_CWRAPPER1(dpotrf, double) |
| 86 | + MXNET_LAPACK_CWRAPPER1(spotri, float) |
| 87 | + MXNET_LAPACK_CWRAPPER1(dpotri, double) |
| 88 | + |
| 89 | +#endif |
| 90 | + |
| 91 | +#endif // MXNET_C_LAPACK_API_H_ |
0 commit comments