Spaces:
Running
Running
// Forward declarations of Tensor, TensorIterator (TI) and TransposeType so the
// dispatch-function signatures in this header can name them without pulling
// in their full definitions.
namespace at {
class Tensor;
struct TensorIterator;
namespace native {
// Opaque enum declaration: the enumerators are defined elsewhere. An opaque
// `enum class` still has a fixed underlying type (int), so it is a complete
// type and may be passed by value.
enum class TransposeType;
} // namespace native
} // namespace at
namespace at::native { | |
enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss}; | |
// Define per-batch functions to be used in the implementation of batched | |
// linear algebra operations | |
template <class scalar_t> | |
void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info); | |
template <class scalar_t> | |
void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info); | |
template <class scalar_t, class value_t=scalar_t> | |
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info); | |
template <class scalar_t> | |
void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); | |
template <class scalar_t> | |
void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); | |
template <class scalar_t> | |
void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info); | |
template <class scalar_t, class value_t = scalar_t> | |
void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info); | |
template <class scalar_t> | |
void lapackGels(char trans, int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
scalar_t *work, int lwork, int *info); | |
template <class scalar_t, class value_t = scalar_t> | |
void lapackGelsd(int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
value_t *s, value_t rcond, int *rank, | |
scalar_t* work, int lwork, | |
value_t *rwork, int* iwork, int *info); | |
template <class scalar_t, class value_t = scalar_t> | |
void lapackGelsy(int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
int *jpvt, value_t rcond, int *rank, | |
scalar_t *work, int lwork, value_t* rwork, int *info); | |
template <class scalar_t, class value_t = scalar_t> | |
void lapackGelss(int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
value_t *s, value_t rcond, int *rank, | |
scalar_t *work, int lwork, | |
value_t *rwork, int *info); | |
template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t> | |
struct lapackLstsq_impl; | |
template <class scalar_t, class value_t> | |
struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> { | |
static void call( | |
char trans, int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
scalar_t *work, int lwork, int *info, // Gels flavor | |
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor | |
value_t *s, // Gelss flavor | |
int *iwork // Gelsd flavor | |
) { | |
lapackGels<scalar_t>( | |
trans, m, n, nrhs, | |
a, lda, b, ldb, | |
work, lwork, info); | |
} | |
}; | |
template <class scalar_t, class value_t> | |
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> { | |
static void call( | |
char trans, int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
scalar_t *work, int lwork, int *info, // Gels flavor | |
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor | |
value_t *s, // Gelss flavor | |
int *iwork // Gelsd flavor | |
) { | |
lapackGelsy<scalar_t, value_t>( | |
m, n, nrhs, | |
a, lda, b, ldb, | |
jpvt, rcond, rank, | |
work, lwork, rwork, info); | |
} | |
}; | |
template <class scalar_t, class value_t> | |
struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> { | |
static void call( | |
char trans, int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
scalar_t *work, int lwork, int *info, // Gels flavor | |
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor | |
value_t *s, // Gelss flavor | |
int *iwork // Gelsd flavor | |
) { | |
lapackGelsd<scalar_t, value_t>( | |
m, n, nrhs, | |
a, lda, b, ldb, | |
s, rcond, rank, | |
work, lwork, | |
rwork, iwork, info); | |
} | |
}; | |
template <class scalar_t, class value_t> | |
struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> { | |
static void call( | |
char trans, int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
scalar_t *work, int lwork, int *info, // Gels flavor | |
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor | |
value_t *s, // Gelss flavor | |
int *iwork // Gelsd flavor | |
) { | |
lapackGelss<scalar_t, value_t>( | |
m, n, nrhs, | |
a, lda, b, ldb, | |
s, rcond, rank, | |
work, lwork, | |
rwork, info); | |
} | |
}; | |
template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t> | |
void lapackLstsq( | |
char trans, int m, int n, int nrhs, | |
scalar_t *a, int lda, scalar_t *b, int ldb, | |
scalar_t *work, int lwork, int *info, // Gels flavor | |
int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor | |
value_t *s, // Gelss flavor | |
int *iwork // Gelsd flavor | |
) { | |
lapackLstsq_impl<driver_type, scalar_t, value_t>::call( | |
trans, m, n, nrhs, | |
a, lda, b, ldb, | |
work, lwork, info, | |
jpvt, rcond, rank, rwork, | |
s, | |
iwork); | |
} | |
template <class scalar_t> | |
void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info); | |
template <class scalar_t> | |
void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info); | |
template <class scalar_t> | |
void lapackLdlHermitian( | |
char uplo, | |
int n, | |
scalar_t* a, | |
int lda, | |
int* ipiv, | |
scalar_t* work, | |
int lwork, | |
int* info); | |
template <class scalar_t> | |
void lapackLdlSymmetric( | |
char uplo, | |
int n, | |
scalar_t* a, | |
int lda, | |
int* ipiv, | |
scalar_t* work, | |
int lwork, | |
int* info); | |
template <class scalar_t> | |
void lapackLdlSolveHermitian( | |
char uplo, | |
int n, | |
int nrhs, | |
scalar_t* a, | |
int lda, | |
int* ipiv, | |
scalar_t* b, | |
int ldb, | |
int* info); | |
template <class scalar_t> | |
void lapackLdlSolveSymmetric( | |
char uplo, | |
int n, | |
int nrhs, | |
scalar_t* a, | |
int lda, | |
int* ipiv, | |
scalar_t* b, | |
int ldb, | |
int* info); | |
template<class scalar_t, class value_t=scalar_t> | |
void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); | |
template <class scalar_t> | |
void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb); | |
using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/); | |
DECLARE_DISPATCH(cholesky_fn, cholesky_stub); | |
using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/); | |
DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub); | |
using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/); | |
DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub); | |
using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/); | |
DECLARE_DISPATCH(geqrf_fn, geqrf_stub); | |
using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/); | |
DECLARE_DISPATCH(orgqr_fn, orgqr_stub); | |
using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/); | |
DECLARE_DISPATCH(ormqr_fn, ormqr_stub); | |
using linalg_eigh_fn = void (*)( | |
const Tensor& /*eigenvalues*/, | |
const Tensor& /*eigenvectors*/, | |
const Tensor& /*infos*/, | |
bool /*upper*/, | |
bool /*compute_eigenvectors*/); | |
DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub); | |
using lstsq_fn = void (*)( | |
const Tensor& /*a*/, | |
Tensor& /*b*/, | |
Tensor& /*rank*/, | |
Tensor& /*singular_values*/, | |
Tensor& /*infos*/, | |
double /*rcond*/, | |
std::string /*driver_name*/); | |
DECLARE_DISPATCH(lstsq_fn, lstsq_stub); | |
using triangular_solve_fn = void (*)( | |
const Tensor& /*A*/, | |
const Tensor& /*B*/, | |
bool /*left*/, | |
bool /*upper*/, | |
TransposeType /*transpose*/, | |
bool /*unitriangular*/); | |
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub); | |
using lu_factor_fn = void (*)( | |
const Tensor& /*input*/, | |
const Tensor& /*pivots*/, | |
const Tensor& /*infos*/, | |
bool /*compute_pivots*/); | |
DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub); | |
using unpack_pivots_fn = void(*)( | |
TensorIterator& iter, | |
const int64_t dim_size, | |
const int64_t max_pivot); | |
DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub); | |
using lu_solve_fn = void (*)( | |
const Tensor& /*LU*/, | |
const Tensor& /*pivots*/, | |
const Tensor& /*B*/, | |
TransposeType /*trans*/); | |
DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub); | |
using ldl_factor_fn = void (*)( | |
const Tensor& /*LD*/, | |
const Tensor& /*pivots*/, | |
const Tensor& /*info*/, | |
bool /*upper*/, | |
bool /*hermitian*/); | |
DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub); | |
using svd_fn = void (*)( | |
const Tensor& /*A*/, | |
const bool /*full_matrices*/, | |
const bool /*compute_uv*/, | |
const c10::optional<c10::string_view>& /*driver*/, | |
const Tensor& /*U*/, | |
const Tensor& /*S*/, | |
const Tensor& /*Vh*/, | |
const Tensor& /*info*/); | |
DECLARE_DISPATCH(svd_fn, svd_stub); | |
using ldl_solve_fn = void (*)( | |
const Tensor& /*LD*/, | |
const Tensor& /*pivots*/, | |
const Tensor& /*result*/, | |
bool /*upper*/, | |
bool /*hermitian*/); | |
DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub); | |
} // namespace at::native | |