// Spaces:
// Running
// Running
namespace at::native::cpublas { | |
namespace internal { | |
void normalize_last_dims( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
int64_t *lda, int64_t *ldb, int64_t *ldc); | |
} // namespace internal | |
using gemm_fn = void(*)( | |
at::ScalarType type, | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
const Scalar& alpha, | |
const void *a, int64_t lda, | |
const void *b, int64_t ldb, | |
const Scalar& beta, | |
void *c, int64_t ldc); | |
DECLARE_DISPATCH(gemm_fn, gemm_stub); | |
template <typename scalar_t> | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
at::opmath_type<scalar_t> alpha, | |
const scalar_t *a, int64_t lda, | |
const scalar_t *b, int64_t ldb, | |
at::opmath_type<scalar_t> beta, | |
scalar_t *c, int64_t ldc) { | |
internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); | |
gemm_stub( | |
kCPU, c10::CppTypeToScalarType<scalar_t>::value, | |
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); | |
} | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
double alpha, | |
const double *a, int64_t lda, | |
const double *b, int64_t ldb, | |
double beta, | |
double *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
float alpha, | |
const float *a, int64_t lda, | |
const float *b, int64_t ldb, | |
float beta, | |
float *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
float alpha, | |
const at::BFloat16 *a, int64_t lda, | |
const at::BFloat16 *b, int64_t ldb, | |
float beta, | |
at::BFloat16 *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
const float alpha, | |
const at::BFloat16 *a, int64_t lda, | |
const at::BFloat16 *b, int64_t ldb, | |
const float beta, | |
float *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
float alpha, | |
const at::Half *a, int64_t lda, | |
const at::Half *b, int64_t ldb, | |
float beta, | |
at::Half *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
const float alpha, | |
const at::Half *a, int64_t lda, | |
const at::Half *b, int64_t ldb, | |
const float beta, | |
float *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
c10::complex<double> alpha, | |
const c10::complex<double> *a, int64_t lda, | |
const c10::complex<double> *b, int64_t ldb, | |
c10::complex<double> beta, | |
c10::complex<double> *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
c10::complex<float> alpha, | |
const c10::complex<float> *a, int64_t lda, | |
const c10::complex<float> *b, int64_t ldb, | |
c10::complex<float> beta, | |
c10::complex<float> *c, int64_t ldc); | |
void gemm( | |
TransposeType transa, TransposeType transb, | |
int64_t m, int64_t n, int64_t k, | |
int64_t alpha, | |
const int64_t *a, int64_t lda, | |
const int64_t *b, int64_t ldb, | |
int64_t beta, | |
int64_t *c, int64_t ldc); | |
template <typename scalar_t> | |
void gemm_batched( | |
TransposeType transa, TransposeType transb, | |
int64_t batch_size, int64_t m, int64_t n, int64_t k, | |
scalar_t alpha, | |
const scalar_t * const *a, int64_t lda, | |
const scalar_t * const *b, int64_t ldb, | |
const scalar_t beta, | |
scalar_t * const *c, int64_t ldc); | |
template <typename scalar_t> | |
void gemm_batched_with_stride( | |
TransposeType transa, TransposeType transb, | |
int64_t batch_size, int64_t m, int64_t n, int64_t k, | |
scalar_t alpha, | |
const scalar_t *a, int64_t lda, int64_t batch_stride_a, | |
const scalar_t *b, int64_t ldb, int64_t batch_stride_b, | |
scalar_t beta, | |
scalar_t *c, int64_t ldc, int64_t batch_stride_c); | |
using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy); | |
DECLARE_DISPATCH(axpy_fn, axpy_stub); | |
template<typename scalar_t> | |
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){ | |
if(n == 1) | |
{ | |
incx = 1; | |
incy = 1; | |
} | |
axpy_stub( | |
kCPU, c10::CppTypeToScalarType<scalar_t>::value, | |
n, a, x, incx, y, incy); | |
} | |
void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy); | |
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy); | |
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy); | |
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy); | |
using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy); | |
DECLARE_DISPATCH(copy_fn, copy_stub); | |
template<typename scalar_t> | |
void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) { | |
if(n == 1) | |
{ | |
incx = 1; | |
incy = 1; | |
} | |
copy_stub( | |
kCPU, c10::CppTypeToScalarType<scalar_t>::value, | |
n, x, incx, y, incy); | |
} | |
void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy); | |
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy); | |
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy); | |
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy); | |
} // namespace at::native::cpublas | |