Doxygen 1.9.1
Toolkit for Adaptive Stochastic Modeling and Non-Intrusive ApproximatioN: Tasmanian v8.2 (development)
tsgBlasWrappers.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2017, Miroslav Stoyanov
3  *
4  * This file is part of
5  * Toolkit for Adaptive Stochastic Modeling And Non-Intrusive ApproximatioN: TASMANIAN
6  *
7  * Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
12  * and the following disclaimer in the documentation and/or other materials provided with the distribution.
13  *
14  * 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse
15  * or promote products derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
18  * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
19  * IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
20  * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
21  * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  * UT-BATTELLE, LLC AND THE UNITED STATES GOVERNMENT MAKE NO REPRESENTATIONS AND DISCLAIM ALL WARRANTIES, BOTH EXPRESSED AND IMPLIED.
25  * THERE ARE NO EXPRESS OR IMPLIED WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, OR THAT THE USE OF THE SOFTWARE WILL NOT INFRINGE ANY PATENT,
26  * COPYRIGHT, TRADEMARK, OR OTHER PROPRIETARY RIGHTS, OR THAT THE SOFTWARE WILL ACCOMPLISH THE INTENDED RESULTS OR THAT THE SOFTWARE OR ITS USE WILL NOT RESULT IN INJURY OR DAMAGE.
27  * THE USER ASSUMES RESPONSIBILITY FOR ALL LIABILITIES, PENALTIES, FINES, CLAIMS, CAUSES OF ACTION, AND COSTS AND EXPENSES, CAUSED BY, RESULTING FROM OR ARISING OUT OF,
28  * IN WHOLE OR IN PART THE USE, STORAGE OR DISPOSAL OF THE SOFTWARE.
29  */
30 
31 #ifndef __TASMANIAN_BLAS_WRAPPERS_HPP
32 #define __TASMANIAN_BLAS_WRAPPERS_HPP
33 
34 #include "tsgEnumerates.hpp"
35 
#ifndef __TASMANIAN_DOXYGEN_SKIP
extern "C"{
// Skip the definitions from Doxygen, this serves as a mock-up header for the BLAS API.
// Convention used below: read-only inputs are declared through const pointers, while arrays that
// BLAS/LAPACK overwrite (e.g., x in dscal_, y in dgemv_, C in dgemm_) are declared non-const.
// BLAS level 1
double dnrm2_(const int *N, const double *x, const int *incx);
void dswap_(const int *N, double *x, const int *incx, double *y, const int *incy);
void dscal_(const int *N, const double *alpha, double *x, const int *incx);
// BLAS level 2
void dgemv_(const char *transa, const int *M, const int *N, const double *alpha, const double *A, const int *lda,
            const double *x, const int *incx, const double *beta, double *y, const int *incy);
void dtrsv_(const char *uplo, const char *trans, const char *diag, const int *N, const double *A, const int *lda,
            double *x, const int *incx);
// BLAS level 3
void dgemm_(const char* transa, const char* transb, const int *m, const int *n, const int *k, const double *alpha,
            const double *A, const int *lda, const double *B, const int *ldb, const double *beta, double *C, const int *ldc);
void dtrsm_(const char *side, const char *uplo, const char *trans, const char *diag, const int *M, const int *N,
            const double *alpha, const double *A, const int *lda, double *B, const int *ldb);
void ztrsm_(const char *side, const char *uplo, const char *trans, const char *diag, const int *M, const int *N,
            const std::complex<double> *alpha, const std::complex<double> *A, const int *lda, std::complex<double> *B, const int *ldb);
// LAPACK solvers
// General PLU factorize/solve
void dgetrf_(const int *M, const int *N, double *A, const int *lda, int *ipiv, int *info);
void dgetrs_(const char *trans, const int *N, const int *nrhs, const double *A, const int *lda, const int *ipiv, double *B, const int *ldb, int *info);
// General least-squares solve
void dgels_(const char *trans, const int *M, const int *N, const int *nrhs, double *A, const int *lda,
            double *B, const int *ldb, double *work, int *lwork, int *info);
void zgels_(const char *trans, const int *M, const int *N, const int *nrhs, std::complex<double> *A, const int *lda,
            std::complex<double> *B, const int *ldb, std::complex<double> *work, int *lwork, int *info);
// Symmetric tridiagonal eigenvalue compute
void dstebz_(const char *range, const char *order, const int *N, const double *vl, const double *vu, const int *il, const int *iu, const double *abstol,
             const double D[], const double E[], int *M, int *nsplit, double W[], int iblock[], int isplit[], double work[], int iwork[], int *info);
void dsteqr_(const char *compz, const int *N, double D[], double E[], double Z[], const int *ldz, double work[], int *info);
void dsterf_(const int *N, double D[], double E[], int *info);
// General LQ-factorize and multiply by Q
#ifdef Tasmanian_BLAS_HAS_ZGELQ
void dgelq_(const int *M, const int *N, double *A, const int *lda, double *T, int const *Tsize, double *work, int const *lwork, int *info);
void dgemlq_(const char *side, const char *trans, const int *M, const int *N, const int *K, double const *A, int const *lda,
             double const *T, int const *Tsize, double C[], int const *ldc, double *work, int const *lwork, int *info);
void zgelq_(const int *M, const int *N, std::complex<double> *A, const int *lda, std::complex<double> *T, int const *Tsize,
            std::complex<double> *work, int const *lwork, int *info);
void zgemlq_(const char *side, const char *trans, const int *M, const int *N, const int *K, std::complex<double> const *A, int const *lda,
             std::complex<double> const *T, int const *Tsize, std::complex<double> C[], int const *ldc, std::complex<double> *work, int const *lwork, int *info);
#endif
}
#endif
93 
98 namespace TasBLAS{
103  inline double norm2(int N, double const x[], int incx){
104  return dnrm2_(&N, x, &incx);
105  }
110  inline void vswap(int N, double x[], int incx, double y[], int incy){
111  dswap_(&N, x, &incx, y, &incy);
112  }
117  inline void scal(int N, double alpha, double x[], int incx){
118  dscal_(&N, &alpha, x, &incx);
119  }
124  inline void gemv(char trans, int M, int N, double alpha, double const A[], int lda, double const x[], int incx,
125  double beta, double y[], int incy){
126  dgemv_(&trans, &M, &N, &alpha, A, &lda, x, &incx, &beta, y, &incy);
127  }
132  inline void trsv(char uplo, char trans, char diag, int N, double const A[], int lda, double x[], int incx){
133  dtrsv_(&uplo, &trans, &diag, &N, A, &lda, x, &incx);
134  }
139  inline void gemm(char transa, char transb, int M, int N, int K, double alpha, double const A[], int lda, double const B[], int ldb,
140  double beta, double C[], int ldc){
141  dgemm_(&transa, &transb, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
142  }
147  inline void trsm(char side, char uplo, char trans, char diag, int M, int N, double alpha, double const A[], int lda, double B[], int ldb){
148  dtrsm_(&side, &uplo, &trans, &diag, &M, &N, &alpha, A, &lda, B, &ldb);
149  }
154  inline void trsm(char side, char uplo, char trans, char diag, int M, int N, std::complex<double> alpha,
155  std::complex<double> const A[], int lda, std::complex<double> B[], int ldb){
156  ztrsm_(&side, &uplo, &trans, &diag, &M, &N, &alpha, A, &lda, B, &ldb);
157  }
162  inline void getrf(int M, int N, double A[], int lda, int ipiv[]){
163  int info = 0;
164  dgetrf_(&M, &N, A, &lda, ipiv, &info);
165  if (info != 0) throw std::runtime_error(std::string("Lapack dgetrf_ exited with code: ") + std::to_string(info));
166  }
171  inline void getrs(char trans, int N, int nrhs, double const A[], int lda, int const ipiv[], double B[], int ldb){
172  int info = 0;
173  dgetrs_(&trans, &N, &nrhs, A, &lda, ipiv, B, &ldb, &info);
174  if (info != 0) throw std::runtime_error(std::string("Lapack dgetrs_ exited with code: ") + std::to_string(info));
175  }
180  inline void gels(char trans, int M, int N, int nrhs, double A[], int lda, double B[], int ldb, double work[], int lwork){
181  int info = 0;
182  dgels_(&trans, &M, &N, &nrhs, A, &lda, B, &ldb, work, &lwork, &info);
183  if (info != 0){
184  if (lwork > 0)
185  throw std::runtime_error(std::string("Lapack dgels_ solve-stage exited with code: ") + std::to_string(info));
186  else
187  throw std::runtime_error(std::string("Lapack dgels_ infer-worksize-stage exited with code: ") + std::to_string(info));
188  }
189  }
194  inline void gels(char trans, int M, int N, int nrhs, std::complex<double> A[], int lda, std::complex<double> B[], int ldb, std::complex<double> work[], int lwork){
195  int info = 0;
196  zgels_(&trans, &M, &N, &nrhs, A, &lda, B, &ldb, work, &lwork, &info);
197  if (info != 0){
198  if (lwork > 0)
199  throw std::runtime_error(std::string("Lapack zgels_ solve-stage exited with code: ") + std::to_string(info));
200  else
201  throw std::runtime_error(std::string("Lapack zgels_ infer-worksize-stage exited with code: ") + std::to_string(info));
202  }
203  }
208  inline void stebz(char range, char order, int N, double vl, double vu, int il, int iu, double abstol, double D[], double E[],
209  int& M, int& nsplit, double W[], int iblock[], int isplit[], double work[], int iwork[]) {
210  int info = 0;
211  dstebz_(&range, &order, &N, &vl, &vu, &il, &iu, &abstol, D, E, &M, &nsplit, W, iblock, isplit, work, iwork, &info);
212  if (info != 0) {
213  if (info <= 3) {
214  throw std::runtime_error(
215  std::string(
216  "Lapack dstebz_ failed to converge for some eigenvalues and exited with code: ") +
217  std::to_string(info));
218  } else if (info == 4) {
219  throw std::runtime_error(
220  std::string("Lapack dstebz_ used a Gershgorin interval that was too small and exited with code: ") +
221  std::to_string(info));
222  } else if (info > 4) {
223  throw std::runtime_error(
224  std::string("Lapack dstebz_ failed and exited with code: ") +
225  std::to_string(info));
226  } else {
227  throw std::runtime_error(
228  std::string(
229  "Lapack dstebz_ had an illegal value at argument number: ") +
230  std::to_string(-info));
231  }
232  }
233  }
238  inline void steqr(char compz, int N, double D[], double E[], double Z[], int ldz, double work[]) {
239  int info = 0;
240  dsteqr_(&compz, &N, D, E, Z, &ldz, work, &info);
241  if (info != 0) {
242  if (info > 0) {
243  throw std::runtime_error(
244  std::string("Lapack dsteqr_ failed to converge for some eigenvalues and exited with code: ") +
245  std::to_string(info));
246  } else {
247  throw std::runtime_error(
248  std::string("Lapack dsteqr_ had an illegal value at argument number: ") +
249  std::to_string(-info));
250  }
251  }
252  }
257  inline void sterf(int N, double D[], double E[]) {
258  int info = 0;
259  dsterf_(&N, D, E, &info);
260  if (info != 0) {
261  if (info > 0) {
262  throw std::runtime_error(
263  std::string("Lapack dsteqr_ failed to converge for some eigenvalues and exited with code: ") +
264  std::to_string(info));
265  } else {
266  throw std::runtime_error(
267  std::string("Lapack dsteqr_ had an illegal value at argument number: ") +
268  std::to_string(-info));
269  }
270  }
271  }
272 #ifdef Tasmanian_BLAS_HAS_ZGELQ
277  inline void geql(int M, int N, double A[], int lda, double T[], int Tsize, double work[], int lwork){
278  int info = 0;
279  dgelq_(&M, &N, A, &lda, T, &Tsize, work, &lwork, &info);
280  if (info != 0){
281  if (lwork > 0)
282  throw std::runtime_error(std::string("Lapack dgeql_ factorize-stage exited with code: ") + std::to_string(info));
283  else
284  throw std::runtime_error(std::string("Lapack dgeql_ infer-worksize-stage exited with code: ") + std::to_string(info));
285  }
286  }
291  inline void gemlq(char side, char trans, int M, int N, int K, double const A[], int lda, double const T[], int Tsize,
292  double C[], int ldc, double work[], int lwork){
293  int info = 0;
294  dgemlq_(&side, &trans, &M, &N, &K, A, &lda, T, &Tsize, C, &ldc, work, &lwork, &info);
295  if (info != 0){
296  if (lwork > 0)
297  throw std::runtime_error(std::string("Lapack dgemlq_ compute-stage exited with code: ") + std::to_string(info));
298  else
299  throw std::runtime_error(std::string("Lapack dgemlq_ infer-worksize-stage exited with code: ") + std::to_string(info));
300  }
301  }
306  inline void geql(int M, int N, std::complex<double> A[], int lda, std::complex<double> T[], int Tsize, std::complex<double> work[], int lwork){
307  int info = 0;
308  zgelq_(&M, &N, A, &lda, T, &Tsize, work, &lwork, &info);
309  if (info != 0){
310  if (lwork > 0)
311  throw std::runtime_error(std::string("Lapack zgeql_ factorize-stage exited with code: ") + std::to_string(info));
312  else
313  throw std::runtime_error(std::string("Lapack zgeql_ infer-worksize-stage exited with code: ") + std::to_string(info));
314  }
315  }
320  inline void gemlq(char side, char trans, int M, int N, int K, std::complex<double> const A[], int lda, std::complex<double> const T[], int Tsize,
321  std::complex<double> C[], int ldc, std::complex<double> work[], int lwork){
322  int info = 0;
323  zgemlq_(&side, &trans, &M, &N, &K, A, &lda, T, &Tsize, C, &ldc, work, &lwork, &info);
324  if (info != 0){
325  if (lwork > 0)
326  throw std::runtime_error(std::string("Lapack zgemlq_ compute-stage exited with code: ") + std::to_string(info));
327  else
328  throw std::runtime_error(std::string("Lapack zgemlq_ infer-worksize-stage exited with code: ") + std::to_string(info));
329  }
330  }
331  #endif
332 
333  // higher-level methods building on top of one or more BLAS/LAPACK Methods
334 
/*!
 * \brief Returns the squared Euclidean norm of the first N entries of x (unit stride),
 *        computed as norm2() multiplied by itself.
 */
template<typename T>
inline T norm2_2(int N, T const x[]){
    T const length = norm2(N, x, 1);
    T const squared = length * length;
    return squared;
}
/*!
 * \brief Computes C = alpha * A * B + beta * C for column-major A (M x K), B (K x N), C (M x N).
 *
 * Dispatches to gemv() when either side degenerates to a vector and to gemm() otherwise.
 */
template<typename T>
inline void denseMultiply(int M, int N, int K, T alpha, const T A[], const T B[], T beta, T C[]){
    if (M <= 1){
        // A is a single row: C^T = B^T * A^T, a transposed matrix-vector product with B
        gemv('T', K, N, alpha, B, K, A, 1, beta, C, 1);
    }else if (N <= 1){
        // B is a single column: C = A * B
        gemv('N', M, K, alpha, A, M, B, 1, beta, C, 1);
    }else{
        gemm('N', 'N', M, N, K, alpha, A, M, B, K, beta, C, M);
    }
}
//! \brief Real overload: conjugation is the identity for real numbers, nothing to do.
inline void conj_matrix(int, int, double[]){}
//! \brief Replaces each of the N * M entries of A with its complex conjugate (in place).
inline void conj_matrix(int N, int M, std::complex<double> A[]){
    // widen before multiplying so that large N * M does not overflow int
    size_t const total = static_cast<size_t>(N) * static_cast<size_t>(M);
    for(std::complex<double> *entry = A, *last = A + total; entry != last; ++entry)
        *entry = std::conj(*entry);
}
//! \brief Transpose flag used for real-valued BLAS/LAPACK calls: plain transpose.
constexpr inline char get_trans(double){
    return 'T';
}
//! \brief Transpose flag used for complex-valued BLAS/LAPACK calls: conjugate transpose.
constexpr inline char get_trans(std::complex<double>){
    return 'C';
}
391  template<typename scalar_type>
392  inline void solveLS(char trans, int N, int M, scalar_type A[], scalar_type b[], int nrhs = 1){
393  std::vector<scalar_type> work(1);
394  int n = (trans == 'N') ? N : M;
395  int m = (trans == 'N') ? M : N;
396  char effective_trans = (trans == 'N') ? trans : get_trans(static_cast<scalar_type>(0.0));
397  conj_matrix(N, M, A); // does nothing in the real case, computes the conjugate in the complex one
398  TasBLAS::gels(effective_trans, n, m, nrhs, A, n, b, N, work.data(), -1);
399  work.resize(static_cast<size_t>(std::real(work[0])));
400  TasBLAS::gels(effective_trans, n, m, nrhs, A, n, b, N, work.data(), static_cast<int>(work.size()));
401  }
/*!
 * \brief Computes the LQ factorization of the column-major rows x cols matrix A (leading dimension rows).
 *
 * A is overwritten with the factors and T is resized and filled with the auxiliary data needed by multiplyQ().
 * Two calls are made: a size query (Tsize = lwork = -1) followed by the actual factorization.
 */
template<typename scalar_type>
inline void factorizeLQ(int rows, int cols, scalar_type A[], std::vector<scalar_type> &T){
    std::vector<scalar_type> work(1);
    T.resize(5); // the query writes its answers into T and work (LAPACK documents Tsize >= 5)
    geql(rows, cols, A, rows, T.data(), -1, work.data(), -1);
    size_t const tsize = static_cast<size_t>(std::real(T[0]));
    size_t const lwork = static_cast<size_t>(std::real(work[0]));
    T.resize(tsize);
    work.resize(lwork);
    geql(rows, cols, A, rows, T.data(), static_cast<int>(T.size()), work.data(), static_cast<int>(work.size()));
}
/*!
 * \brief Multiplies C from the right by the (conjugate) transpose of the Q factor computed with factorizeLQ().
 *
 * \param M, N  dimensions of C, column-major with leading dimension M; C is overwritten.
 * \param K     number of reflectors; A is the factored matrix from factorizeLQ() with leading dimension K.
 * \param T     auxiliary data produced by factorizeLQ().
 *
 * The transpose flag comes from get_trans(): 'T' for real and 'C' for complex scalars.
 */
template<typename scalar_type>
inline void multiplyQ(int M, int N, int K, scalar_type const A[], std::vector<scalar_type> const &T, scalar_type C[]){
    char const trans = get_trans(static_cast<scalar_type>(0.0));
    int const tsize = static_cast<int>(T.size());
    std::vector<scalar_type> work(1);
    // workspace query (lwork = -1): the optimal size is returned in work[0]
    gemlq('R', trans, M, N, K, A, K, T.data(), tsize, C, M, work.data(), -1);
    // convert straight to size_t (as in factorizeLQ/solveLS) instead of narrowing through int
    work.resize(static_cast<size_t>(std::real(work[0])));
    gemlq('R', trans, M, N, K, A, K, T.data(), tsize, C, M, work.data(), static_cast<int>(work.size()));
}
436  template<typename scalar_type>
437  void solveLSmulti(int n, int m, scalar_type A[], int nrhs, scalar_type B[]){
438  if (nrhs == 1){
439  TasBLAS::solveLS('T', n, m, A, B);
440  }else{
441  #ifdef Tasmanian_BLAS_HAS_ZGELQ
442  std::vector<scalar_type> T;
443  TasBLAS::factorizeLQ(m, n, A, T);
444  TasBLAS::multiplyQ(nrhs, n, m, A, T, B);
445  TasBLAS::trsm('R', 'L', 'N', 'N', nrhs, m, 1.0, A, m, B, nrhs);
446  #else
447  auto Bcols = TasGrid::Utils::transpose(nrhs, n, B);
448  TasBLAS::solveLS('T', n, m, A, Bcols.data(), nrhs);
449  TasGrid::Utils::transpose(n, nrhs, Bcols.data(), B);
450  #endif
451  }
452  }
453 }
454 
455 #endif
void vswap(int N, double x[], int incx, double y[], int incy)
BLAS dswap.
Definition: tsgBlasWrappers.hpp:110
void conj_matrix(int, int, double[])
Conjugates a matrix, no-op in the real case.
Definition: tsgBlasWrappers.hpp:367
void gels(char trans, int M, int N, int nrhs, double A[], int lda, double B[], int ldb, double work[], int lwork)
LAPACK dgels.
Definition: tsgBlasWrappers.hpp:180
double norm2(int N, double const x[], int incx)
BLAS dnrm2.
Definition: tsgBlasWrappers.hpp:103
void steqr(char compz, int N, double D[], double E[], double Z[], int ldz, double work[])
LAPACK dsteqr.
Definition: tsgBlasWrappers.hpp:238
void trsm(char side, char uplo, char trans, char diag, int M, int N, double alpha, double const A[], int lda, double B[], int ldb)
BLAS dtrsm.
Definition: tsgBlasWrappers.hpp:147
void gemm(char transa, char transb, int M, int N, int K, double alpha, double const A[], int lda, double const B[], int ldb, double beta, double C[], int ldc)
BLAS gemm.
Definition: tsgBlasWrappers.hpp:139
void trsv(char uplo, char trans, char diag, int N, double const A[], int lda, double x[], int incx)
BLAS dtrsv.
Definition: tsgBlasWrappers.hpp:132
void getrs(char trans, int N, int nrhs, double const A[], int lda, int const ipiv[], double B[], int ldb)
LAPACK dgetrs.
Definition: tsgBlasWrappers.hpp:171
void factorizeLQ(int rows, int cols, scalar_type A[], std::vector< scalar_type > &T)
Compute the LQ factorization of the matrix A.
Definition: tsgBlasWrappers.hpp:410
void stebz(char range, char order, int N, double vl, double vu, int il, int iu, double abstol, double D[], double E[], int &M, int &nsplit, double W[], int iblock[], int isplit[], double work[], int iwork[])
LAPACK dstebz.
Definition: tsgBlasWrappers.hpp:208
void getrf(int M, int N, double A[], int lda, int ipiv[])
LAPACK dgetrf.
Definition: tsgBlasWrappers.hpp:162
void denseMultiply(int M, int N, int K, T alpha, const T A[], const T B[], T beta, T C[])
Combination of BLAS gemm and gemv.
Definition: tsgBlasWrappers.hpp:352
void sterf(int N, double D[], double E[])
LAPACK dsterf.
Definition: tsgBlasWrappers.hpp:257
void gemv(char trans, int M, int N, double alpha, double const A[], int lda, double const x[], int incx, double beta, double y[], int incy)
BLAS dgemv.
Definition: tsgBlasWrappers.hpp:124
void scal(int N, double alpha, double x[], int incx)
BLAS dscal.
Definition: tsgBlasWrappers.hpp:117
void solveLSmulti(int n, int m, scalar_type A[], int nrhs, scalar_type B[])
Solves the least-squares problem assuming row-major format, see TasmanianDenseSolver::solvesLeastSquares()
Definition: tsgBlasWrappers.hpp:437
void solveLS(char trans, int N, int M, scalar_type A[], scalar_type b[], int nrhs=1)
Solves the over-determined least-squares problem with one or more right-hand sides (a single one by default).
Definition: tsgBlasWrappers.hpp:392
T norm2_2(int N, T const x[])
Returns the square of the norm of the vector.
Definition: tsgBlasWrappers.hpp:340
constexpr char get_trans(double)
Returns the transpose symbol, 'T' in the real case.
Definition: tsgBlasWrappers.hpp:379
void multiplyQ(int M, int N, int K, scalar_type const A[], std::vector< scalar_type > const &T, scalar_type C[])
Multiplies C by the Q factor computed with factorizeLQ.
Definition: tsgBlasWrappers.hpp:426
Wrappers for BLAS and LAPACK methods (hidden internal namespace).
Definition: tsgBlasWrappers.hpp:98
void transpose(long long M, long long N, scalar_type const A[], scalar_type B[])
Constructs the transpose of an M by N matrix A in column major format, result is stored in B (impleme...
Omnipresent enumerate types.