31 #ifndef __TASMANIAN_BLAS_WRAPPERS_HPP
32 #define __TASMANIAN_BLAS_WRAPPERS_HPP
#ifndef __TASMANIAN_DOXYGEN_SKIP
// Prototypes of the Fortran BLAS/LAPACK routines used by the TasBLAS wrappers.
// Hidden from Doxygen via __TASMANIAN_DOXYGEN_SKIP (presumably defined only when generating docs).
// Every argument is passed by pointer (Fortran calling convention); the trailing-underscore names
// assume the common gfortran-style symbol naming of the linked BLAS/LAPACK library.
// Array arguments that the routine overwrites are declared non-const, everything else is const.
extern "C"{
// BLAS level 1
double dnrm2_(const int *N, const double *x, const int *incx);
void dswap_(const int *N, double *x, const int *incx, double *y, const int *incy);
void dscal_(const int *N, const double *alpha, double *x, const int *incx); // x is scaled in place
// BLAS level 2
void dgemv_(const char *transa, const int *M, const int *N, const double *alpha, const double *A, const int *lda,
            const double *x, const int *incx, const double *beta, double *y, const int *incy); // y is output
void dtrsv_(const char *uplo, const char *trans, const char *diag, const int *N, const double *A, const int *lda,
            double *x, const int *incx);
// BLAS level 3
void dgemm_(const char* transa, const char* transb, const int *m, const int *n, const int *k, const double *alpha,
            const double *A, const int *lda, const double *B, const int *ldb, const double *beta,
            double *C, const int *ldc); // C is output
void dtrsm_(const char *side, const char *uplo, const char *trans, const char *diag, const int *M, const int *N,
            const double *alpha, const double *A, const int *lda, double *B, const int *ldb);
void ztrsm_(const char *side, const char *uplo, const char *trans, const char *diag, const int *M, const int *N,
            const std::complex<double> *alpha, const std::complex<double> *A, const int *lda,
            std::complex<double> *B, const int *ldb);
// LAPACK general linear solver
void dgetrf_(const int *M, const int *N, double *A, const int *lda, int *ipiv, int *info);
void dgetrs_(const char *trans, const int *N, const int *nrhs, const double *A, const int *lda, const int *ipiv,
             double *B, const int *ldb, int *info);
// LAPACK least-squares (lwork is input only)
void dgels_(const char *trans, const int *M, const int *N, const int *nrhs, double *A, const int *lda,
            double *B, const int *ldb, double *work, const int *lwork, int *info);
void zgels_(const char *trans, const int *M, const int *N, const int *nrhs, std::complex<double> *A, const int *lda,
            std::complex<double> *B, const int *ldb, std::complex<double> *work, const int *lwork, int *info);
// LAPACK symmetric tridiagonal eigen-solvers
void dstebz_(const char *range, const char *order, const int *N, const double *vl, const double *vu,
             const int *il, const int *iu, const double *abstol,
             const double D[], const double E[], int *M, int *nsplit, double W[], int iblock[], int isplit[],
             double work[], int iwork[], int *info);
void dsteqr_(const char *compz, const int *N, double D[], double E[], double Z[], const int *ldz,
             double work[], int *info);
void dsterf_(const int *N, double D[], double E[], int *info);
#ifdef Tasmanian_BLAS_HAS_ZGELQ
// LAPACK LQ factorization and multiplication by Q (only when the library provides them)
void dgelq_(const int *M, const int *N, double *A, const int *lda, double *T, int const *Tsize,
            double *work, int const *lwork, int *info);
void dgemlq_(const char *side, const char *trans, const int *M, const int *N, const int *K,
             double const *A, int const *lda, double const *T, int const *Tsize,
             double C[], int const *ldc, double *work, int const *lwork, int *info);
void zgelq_(const int *M, const int *N, std::complex<double> *A, const int *lda,
            std::complex<double> *T, int const *Tsize,
            std::complex<double> *work, int const *lwork, int *info);
void zgemlq_(const char *side, const char *trans, const int *M, const int *N, const int *K,
             std::complex<double> const *A, int const *lda, std::complex<double> const *T, int const *Tsize,
             std::complex<double> C[], int const *ldc, std::complex<double> *work, int const *lwork, int *info);
#endif // Tasmanian_BLAS_HAS_ZGELQ
} // extern "C"
#endif // __TASMANIAN_DOXYGEN_SKIP
103 inline double norm2(
int N,
double const x[],
int incx){
104 return dnrm2_(&N, x, &incx);
110 inline void vswap(
int N,
double x[],
int incx,
double y[],
int incy){
111 dswap_(&N, x, &incx, y, &incy);
117 inline void scal(
int N,
double alpha,
double x[],
int incx){
118 dscal_(&N, &alpha, x, &incx);
124 inline void gemv(
char trans,
int M,
int N,
double alpha,
double const A[],
int lda,
double const x[],
int incx,
125 double beta,
double y[],
int incy){
126 dgemv_(&trans, &M, &N, &alpha, A, &lda, x, &incx, &beta, y, &incy);
132 inline void trsv(
char uplo,
char trans,
char diag,
int N,
double const A[],
int lda,
double x[],
int incx){
133 dtrsv_(&uplo, &trans, &diag, &N, A, &lda, x, &incx);
139 inline void gemm(
char transa,
char transb,
int M,
int N,
int K,
double alpha,
double const A[],
int lda,
double const B[],
int ldb,
140 double beta,
double C[],
int ldc){
141 dgemm_(&transa, &transb, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
147 inline void trsm(
char side,
char uplo,
char trans,
char diag,
int M,
int N,
double alpha,
double const A[],
int lda,
double B[],
int ldb){
148 dtrsm_(&side, &uplo, &trans, &diag, &M, &N, &alpha, A, &lda, B, &ldb);
154 inline void trsm(
char side,
char uplo,
char trans,
char diag,
int M,
int N, std::complex<double> alpha,
155 std::complex<double>
const A[],
int lda, std::complex<double> B[],
int ldb){
156 ztrsm_(&side, &uplo, &trans, &diag, &M, &N, &alpha, A, &lda, B, &ldb);
162 inline void getrf(
int M,
int N,
double A[],
int lda,
int ipiv[]){
164 dgetrf_(&M, &N, A, &lda, ipiv, &info);
165 if (info != 0)
throw std::runtime_error(std::string(
"Lapack dgetrf_ exited with code: ") + std::to_string(info));
171 inline void getrs(
char trans,
int N,
int nrhs,
double const A[],
int lda,
int const ipiv[],
double B[],
int ldb){
173 dgetrs_(&trans, &N, &nrhs, A, &lda, ipiv, B, &ldb, &info);
174 if (info != 0)
throw std::runtime_error(std::string(
"Lapack dgetrs_ exited with code: ") + std::to_string(info));
180 inline void gels(
char trans,
int M,
int N,
int nrhs,
double A[],
int lda,
double B[],
int ldb,
double work[],
int lwork){
182 dgels_(&trans, &M, &N, &nrhs, A, &lda, B, &ldb, work, &lwork, &info);
185 throw std::runtime_error(std::string(
"Lapack dgels_ solve-stage exited with code: ") + std::to_string(info));
187 throw std::runtime_error(std::string(
"Lapack dgels_ infer-worksize-stage exited with code: ") + std::to_string(info));
194 inline void gels(
char trans,
int M,
int N,
int nrhs, std::complex<double> A[],
int lda, std::complex<double> B[],
int ldb, std::complex<double> work[],
int lwork){
196 zgels_(&trans, &M, &N, &nrhs, A, &lda, B, &ldb, work, &lwork, &info);
199 throw std::runtime_error(std::string(
"Lapack zgels_ solve-stage exited with code: ") + std::to_string(info));
201 throw std::runtime_error(std::string(
"Lapack zgels_ infer-worksize-stage exited with code: ") + std::to_string(info));
208 inline void stebz(
char range,
char order,
int N,
double vl,
double vu,
int il,
int iu,
double abstol,
double D[],
double E[],
209 int& M,
int& nsplit,
double W[],
int iblock[],
int isplit[],
double work[],
int iwork[]) {
211 dstebz_(&range, &order, &N, &vl, &vu, &il, &iu, &abstol, D, E, &M, &nsplit, W, iblock, isplit, work, iwork, &info);
214 throw std::runtime_error(
216 "Lapack dstebz_ failed to converge for some eigenvalues and exited with code: ") +
217 std::to_string(info));
218 }
else if (info == 4) {
219 throw std::runtime_error(
220 std::string(
"Lapack dstebz_ used a Gershgorin interval that was too small and exited with code: ") +
221 std::to_string(info));
222 }
else if (info > 4) {
223 throw std::runtime_error(
224 std::string(
"Lapack dstebz_ failed and exited with code: ") +
225 std::to_string(info));
227 throw std::runtime_error(
229 "Lapack dstebz_ had an illegal value at argument number: ") +
230 std::to_string(-info));
238 inline void steqr(
char compz,
int N,
double D[],
double E[],
double Z[],
int ldz,
double work[]) {
240 dsteqr_(&compz, &N, D, E, Z, &ldz, work, &info);
243 throw std::runtime_error(
244 std::string(
"Lapack dsteqr_ failed to converge for some eigenvalues and exited with code: ") +
245 std::to_string(info));
247 throw std::runtime_error(
248 std::string(
"Lapack dsteqr_ had an illegal value at argument number: ") +
249 std::to_string(-info));
257 inline void sterf(
int N,
double D[],
double E[]) {
259 dsterf_(&N, D, E, &info);
262 throw std::runtime_error(
263 std::string(
"Lapack dsteqr_ failed to converge for some eigenvalues and exited with code: ") +
264 std::to_string(info));
266 throw std::runtime_error(
267 std::string(
"Lapack dsteqr_ had an illegal value at argument number: ") +
268 std::to_string(-info));
272 #ifdef Tasmanian_BLAS_HAS_ZGELQ
277 inline void geql(
int M,
int N,
double A[],
int lda,
double T[],
int Tsize,
double work[],
int lwork){
279 dgelq_(&M, &N, A, &lda, T, &Tsize, work, &lwork, &info);
282 throw std::runtime_error(std::string(
"Lapack dgeql_ factorize-stage exited with code: ") + std::to_string(info));
284 throw std::runtime_error(std::string(
"Lapack dgeql_ infer-worksize-stage exited with code: ") + std::to_string(info));
291 inline void gemlq(
char side,
char trans,
int M,
int N,
int K,
double const A[],
int lda,
double const T[],
int Tsize,
292 double C[],
int ldc,
double work[],
int lwork){
294 dgemlq_(&side, &trans, &M, &N, &K, A, &lda, T, &Tsize, C, &ldc, work, &lwork, &info);
297 throw std::runtime_error(std::string(
"Lapack dgemlq_ compute-stage exited with code: ") + std::to_string(info));
299 throw std::runtime_error(std::string(
"Lapack dgemlq_ infer-worksize-stage exited with code: ") + std::to_string(info));
306 inline void geql(
int M,
int N, std::complex<double> A[],
int lda, std::complex<double> T[],
int Tsize, std::complex<double> work[],
int lwork){
308 zgelq_(&M, &N, A, &lda, T, &Tsize, work, &lwork, &info);
311 throw std::runtime_error(std::string(
"Lapack zgeql_ factorize-stage exited with code: ") + std::to_string(info));
313 throw std::runtime_error(std::string(
"Lapack zgeql_ infer-worksize-stage exited with code: ") + std::to_string(info));
320 inline void gemlq(
char side,
char trans,
int M,
int N,
int K, std::complex<double>
const A[],
int lda, std::complex<double>
const T[],
int Tsize,
321 std::complex<double> C[],
int ldc, std::complex<double> work[],
int lwork){
323 zgemlq_(&side, &trans, &M, &N, &K, A, &lda, T, &Tsize, C, &ldc, work, &lwork, &info);
326 throw std::runtime_error(std::string(
"Lapack zgemlq_ compute-stage exited with code: ") + std::to_string(info));
328 throw std::runtime_error(std::string(
"Lapack zgemlq_ infer-worksize-stage exited with code: ") + std::to_string(info));
/*!
 * \brief Returns the square of the Euclidean norm of the contiguous vector \b x of length \b N.
 *
 * Computed as norm2(N, x, 1) squared.
 */
template<typename T>
inline T norm2_2(int N, T const x[]){
    T nrm = norm2(N, x, 1);
    return nrm * nrm;
}
/*!
 * \brief Computes C = alpha * A * B + beta * C using BLAS gemm, or gemv when one side is a vector.
 *
 * Column-major storage: A is M by K (leading dimension M), B is K by N (leading dimension K),
 * C is M by N (leading dimension M).
 * - M > 1 and N > 1: general gemm.
 * - M > 1 and N == 1: C is a column, computed as gemv('N') of A times the vector B.
 * - M == 1: C is a row, computed as gemv('T') of B times the vector A.
 * Nothing is done when C is empty (M or N not positive). Without this check, the M < 1
 * path would call gemv('T') and write N entries into an empty C.
 */
template<typename T>
inline void denseMultiply(int M, int N, int K, T alpha, const T A[], const T B[], T beta, T C[]){
    if (M < 1 || N < 1) return; // empty output, nothing to compute
    if (M > 1){
        if (N > 1){ // matrix-matrix product
            gemm('N', 'N', M, N, K, alpha, A, M, B, K, beta, C, M);
        }else{ // matrix-vector: A * B = C
            gemv('N', M, K, alpha, A, M, B, 1, beta, C, 1);
        }
    }else{ // row-vector times matrix: C^T = B^T * A^T
        gemv('T', K, N, alpha, B, K, A, 1, beta, C, 1);
    }
}
/*!
 * \brief Conjugates a matrix in place, no op in the real case.
 */
inline void conj_matrix(int, int, double[]){}

/*!
 * \brief Replaces every entry of the \b N by \b M complex matrix \b A with its complex conjugate.
 *
 * The product N * M is formed in size_t to avoid int overflow for large matrices.
 */
inline void conj_matrix(int N, int M, std::complex<double> A[]){
    size_t const total = static_cast<size_t>(N) * static_cast<size_t>(M);
    for(size_t i=0; i<total; i++) A[i] = std::conj(A[i]);
}
/*!
 * \brief Returns the transpose symbol, 'T' in the real case.
 *
 * Without this overload, a double argument would implicitly convert to std::complex<double>
 * and yield 'C', which the real LAPACK routines reject.
 */
constexpr inline char get_trans(double){ return 'T'; }

/*!
 * \brief Returns the conjugate-transpose symbol 'C' used by the complex LAPACK routines.
 */
constexpr inline char get_trans(std::complex<double>){ return 'C'; }
391 template<
typename scalar_type>
392 inline void solveLS(
char trans,
int N,
int M, scalar_type A[], scalar_type b[],
int nrhs = 1){
393 std::vector<scalar_type> work(1);
394 int n = (trans ==
'N') ? N : M;
395 int m = (trans ==
'N') ? M : N;
396 char effective_trans = (trans ==
'N') ? trans :
get_trans(
static_cast<scalar_type
>(0.0));
398 TasBLAS::gels(effective_trans, n, m, nrhs, A, n, b, N, work.data(), -1);
399 work.resize(
static_cast<size_t>(std::real(work[0])));
400 TasBLAS::gels(effective_trans, n, m, nrhs, A, n, b, N, work.data(),
static_cast<int>(work.size()));
/*!
 * \brief Computes the LQ factorization of the \b rows by \b cols column-major matrix \b A.
 *
 * \b A is overwritten with the factors. \b T is resized to the size requested by LAPACK and
 * receives the data needed later by multiplyQ().
 * The LAPACK size query writes up to 5 entries of \b T (dimension max(5, TSIZE)), so \b T is
 * grown to at least 5 entries first. Callers may pass an empty vector.
 */
template<typename scalar_type>
inline void factorizeLQ(int rows, int cols, scalar_type A[], std::vector<scalar_type> &T){
    if (T.size() < 5) T.resize(5);
    std::vector<scalar_type> work(1);
    geql(rows, cols, A, rows, T.data(), -1, work.data(), -1); // size query for T and work
    T.resize(static_cast<size_t>(std::real(T[0])));
    work.resize(static_cast<size_t>(std::real(work[0])));
    geql(rows, cols, A, rows, T.data(), static_cast<int>(T.size()), work.data(), static_cast<int>(work.size()));
}
/*!
 * \brief Multiplies C from the right by the (conjugate-)transpose of the Q factor computed with factorizeLQ().
 *
 * \b C is M by N with leading dimension M and is overwritten. \b A and \b T are the outputs of
 * factorizeLQ(), with \b K the number of reflectors and the leading dimension of \b A.
 * A workspace-size query (lwork = -1) is issued first and the workspace is sized from its result.
 */
template<typename scalar_type>
inline void multiplyQ(int M, int N, int K, scalar_type const A[], std::vector<scalar_type> const &T, scalar_type C[]){
    char const trans = get_trans(static_cast<scalar_type>(0.0));
    int const Tsize = static_cast<int>(T.size());
    std::vector<scalar_type> work(1);
    gemlq('R', trans, M, N, K, A, K, T.data(), Tsize, C, M, work.data(), -1);
    work.resize(static_cast<size_t>(std::real(work[0])));
    gemlq('R', trans, M, N, K, A, K, T.data(), Tsize, C, M, work.data(), static_cast<int>(work.size()));
}
436 template<
typename scalar_type>
437 void solveLSmulti(
int n,
int m, scalar_type A[],
int nrhs, scalar_type B[]){
441 #ifdef Tasmanian_BLAS_HAS_ZGELQ
442 std::vector<scalar_type> T;
445 TasBLAS::trsm(
'R',
'L',
'N',
'N', nrhs, m, 1.0, A, m, B, nrhs);
void vswap(int N, double x[], int incx, double y[], int incy)
BLAS dswap.
Definition: tsgBlasWrappers.hpp:110
void conj_matrix(int, int, double[])
Conjugates a matrix, no op in the real case.
Definition: tsgBlasWrappers.hpp:367
void gels(char trans, int M, int N, int nrhs, double A[], int lda, double B[], int ldb, double work[], int lwork)
LAPACK dgels.
Definition: tsgBlasWrappers.hpp:180
double norm2(int N, double const x[], int incx)
BLAS dnrm2.
Definition: tsgBlasWrappers.hpp:103
void steqr(char compz, int N, double D[], double E[], double Z[], int ldz, double work[])
LAPACK dsteqr.
Definition: tsgBlasWrappers.hpp:238
void trsm(char side, char uplo, char trans, char diag, int M, int N, double alpha, double const A[], int lda, double B[], int ldb)
BLAS dtrsm.
Definition: tsgBlasWrappers.hpp:147
void gemm(char transa, char transb, int M, int N, int K, double alpha, double const A[], int lda, double const B[], int ldb, double beta, double C[], int ldc)
BLAS gemm.
Definition: tsgBlasWrappers.hpp:139
void trsv(char uplo, char trans, char diag, int N, double const A[], int lda, double x[], int incx)
BLAS dtrsv.
Definition: tsgBlasWrappers.hpp:132
void getrs(char trans, int N, int nrhs, double const A[], int lda, int const ipiv[], double B[], int ldb)
LAPACK dgetrs.
Definition: tsgBlasWrappers.hpp:171
void factorizeLQ(int rows, int cols, scalar_type A[], std::vector< scalar_type > &T)
Compute the LQ factorization of the matrix A.
Definition: tsgBlasWrappers.hpp:410
void stebz(char range, char order, int N, double vl, double vu, int il, int iu, double abstol, double D[], double E[], int &M, int &nsplit, double W[], int iblock[], int isplit[], double work[], int iwork[])
LAPACK dstebz.
Definition: tsgBlasWrappers.hpp:208
void getrf(int M, int N, double A[], int lda, int ipiv[])
LAPACK dgetrf.
Definition: tsgBlasWrappers.hpp:162
void denseMultiply(int M, int N, int K, T alpha, const T A[], const T B[], T beta, T C[])
Combination of BLAS gemm and gemv.
Definition: tsgBlasWrappers.hpp:352
void sterf(int N, double D[], double E[])
LAPACK dsterf.
Definition: tsgBlasWrappers.hpp:257
void gemv(char trans, int M, int N, double alpha, double const A[], int lda, double const x[], int incx, double beta, double y[], int incy)
BLAS dgemv.
Definition: tsgBlasWrappers.hpp:124
void scal(int N, double alpha, double x[], int incx)
BLAS dscal.
Definition: tsgBlasWrappers.hpp:117
void solveLSmulti(int n, int m, scalar_type A[], int nrhs, scalar_type B[])
Solves the least-squares problem assuming row-major format; see TasmanianDenseSolver::solvesLeastSquares().
Definition: tsgBlasWrappers.hpp:437
void solveLS(char trans, int N, int M, scalar_type A[], scalar_type b[], int nrhs=1)
Solves the over-determined least-squares problem with one or more right-hand sides (one by default).
Definition: tsgBlasWrappers.hpp:392
T norm2_2(int N, T const x[])
Returns the square of the norm of the vector.
Definition: tsgBlasWrappers.hpp:340
constexpr char get_trans(double)
Returns the transpose symbol, 'T' in the real case.
Definition: tsgBlasWrappers.hpp:379
void multiplyQ(int M, int N, int K, scalar_type const A[], std::vector< scalar_type > const &T, scalar_type C[])
Multiplies C by the Q factor computed with factorizeLQ.
Definition: tsgBlasWrappers.hpp:426
Wrappers for BLAS and LAPACK methods (hidden internal namespace).
Definition: tsgBlasWrappers.hpp:98
void transpose(long long M, long long N, scalar_type const A[], scalar_type B[])
Constructs the transpose of an M by N matrix A in column major format, result is stored in B (impleme...
Omnipresent enumerate types.