25 #ifndef MSHADOW_DOT_ENGINE_INL_H_ 26 #define MSHADOW_DOT_ENGINE_INL_H_ 33 #include "./cuda/tensor_gpu-inl.cuh" 34 #endif // #ifdef __CUDACC__ 45 template<
typename Device,
typename DType>
46 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
47 Stream<Device> *stream);
48 template<
typename DType>
51 for (
int i = 0; i < num; i++) {
52 dst[i] = src + i * stride;
57 template<
typename DType>
58 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
62 #endif // #ifdef __CUDACC__ 68 template<
typename SV,
typename Device,
int ddim,
int ldim,
69 int rdim,
bool ltrans,
bool rtrans,
typename DType>
77 template<
typename Device,
typename DType = default_real_t>
79 inline static bool GetT(
bool t) {
80 return t ? true :
false;
85 bool transa,
bool transb,
86 int m,
int n,
int k, DType alpha,
87 const DType *A,
int lda,
const DType *B,
int ldb,
88 DType beta, DType *C,
int ldc) {
89 LOG(FATAL) <<
"Not implmented!";
92 bool transa,
bool transb,
93 int m,
int n,
int k, DType alpha,
94 const DType *A,
int lda,
const DType *B,
int ldb,
95 DType beta, DType *C,
int ldc,
int batch_count,
97 LOG(FATAL) <<
"Not implmented!";
100 bool trans,
int m,
int n,
101 DType alpha,
const DType *A,
int lda,
102 const DType *X,
int incX,
103 DType beta, DType *Y,
int incY) {
104 LOG(FATAL) <<
"Not implmented!";
107 bool trans,
int m,
int n,
108 DType alpha,
const DType *A,
int lda,
109 const DType *X,
int incX,
110 DType beta, DType *Y,
int incY,
int batch_count) {
111 LOG(FATAL) <<
"Not implmented!";
114 int m,
int n, DType alpha,
115 const DType *X,
int incX,
116 const DType *Y,
int incY, DType *A,
int lda) {
117 LOG(FATAL) <<
"Not implmented!";
120 int m,
int n, DType alpha,
121 const DType *X,
int incX,
122 const DType *Y,
int incY, DType *A,
int lda,
int batch_count) {
123 LOG(FATAL) <<
"Not implmented!";
127 const DType* X,
int incX,
128 const DType* Y,
int incY,
130 LOG(FATAL) <<
"Not implmented!";
134 #if MSHADOW_STAND_ALONE 137 inline static bool GetT(
bool t) {
138 return t ? true :
false;
140 inline static void SetStream(
Stream<cpu> *stream) {
143 bool transa,
bool transb,
144 int m,
int n,
int k,
float alpha,
145 const float *A,
int lda,
const float *B,
int ldb,
146 float beta,
float *C,
int ldc) {
147 if (alpha == 1.0f && beta == 0.0f) {
148 bool transpose_left = transb;
149 bool transpose_right = transa;
153 if (!transpose_left && !transpose_right) {
155 }
else if (!transpose_left && transpose_right) {
157 }
else if (transpose_left && !transpose_right) {
160 LOG(FATAL) <<
"Not implmented!";
163 LOG(FATAL) <<
"Not implmented!";
166 inline static void batched_gemm(
Stream<cpu> *stream,
167 bool transa,
bool transb,
168 int m,
int n,
int k,
float alpha,
169 const float *A,
int lda,
const float *B,
int ldb,
170 float beta,
float *C,
int ldc,
int batch_count,
172 for (
int i = 0; i < batch_count; ++i) {
173 gemm(stream, transa, transb, m, n, k, alpha,
174 A + i * m * k, lda, B + i * k * n, ldb,
175 beta, C + i * m * n, ldc);
179 bool trans,
int m,
int n,
180 float alpha,
const float *A,
int lda,
181 const float *X,
int incX,
182 float beta,
float *Y,
int incY) {
183 LOG(FATAL) <<
"Not implmented!";
185 inline static void batched_gemv(
Stream<cpu> *stream,
186 bool trans,
int m,
int n,
187 float alpha,
const float *A,
int lda,
188 const float *X,
int incX,
189 float beta,
float *Y,
int incY,
int batch_count) {
190 LOG(FATAL) <<
"Not implmented!";
193 int m,
int n,
float alpha,
194 const float *X,
int incX,
195 const float *Y,
int incY,
float *A,
int lda) {
196 LOG(FATAL) <<
"Not implmented!";
198 inline static void batched_ger(
Stream<cpu> *stream,
199 int m,
int n,
float alpha,
200 const float *X,
int incX,
201 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
202 LOG(FATAL) <<
"Not implmented!";
206 const float* X,
int incX,
207 const float* Y,
int incY,
209 LOG(FATAL) <<
"Not implmented!";
215 inline static bool GetT(
bool t) {
216 return t ? true :
false;
218 inline static void SetStream(
Stream<cpu> *stream) {
221 bool transa,
bool transb,
222 int m,
int n,
int k,
double alpha,
223 const double *A,
int lda,
const double *B,
int ldb,
224 double beta,
double *C,
int ldc) {
225 if (alpha == 1.0f && beta == 0.0f) {
226 bool transpose_left = transb;
227 bool transpose_right = transa;
231 if (!transpose_left && !transpose_right) {
233 }
else if (!transpose_left && transpose_right) {
235 }
else if (transpose_left && !transpose_right) {
238 LOG(FATAL) <<
"Not implmented!";
241 LOG(FATAL) <<
"Not implmented!";
244 inline static void batched_gemm(
Stream<cpu> *stream,
245 bool transa,
bool transb,
246 int m,
int n,
int k,
double alpha,
247 const double *A,
int lda,
const double *B,
int ldb,
248 double beta,
double *C,
int ldc,
int batch_count,
249 double **workspace) {
250 for (
int i = 0; i < batch_count; ++i) {
251 gemm(stream, transa, transb, m, n, k, alpha,
252 A + i * m * k, lda, B + i * k * n, ldb,
253 beta, C + i * m * n, ldc);
257 bool trans,
int m,
int n,
258 double alpha,
const double *A,
int lda,
259 const double *X,
int incX,
260 double beta,
double *Y,
int incY) {
261 LOG(FATAL) <<
"Not implmented!";
263 inline static void batched_gemv(
Stream<cpu> *stream,
264 bool trans,
int m,
int n,
265 double alpha,
const double *A,
int lda,
266 const double *X,
int incX,
267 double beta,
double *Y,
int incY,
int batch_count) {
268 LOG(FATAL) <<
"Not implmented!";
271 int m,
int n,
double alpha,
272 const double *X,
int incX,
273 const double *Y,
int incY,
double *A,
int lda) {
274 LOG(FATAL) <<
"Not implmented!";
276 inline static void batched_ger(
Stream<cpu> *stream,
277 int m,
int n,
double alpha,
278 const double *X,
int incX,
279 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
280 LOG(FATAL) <<
"Not implmented!";
284 const double* X,
int incX,
285 const double* Y,
int incY,
287 LOG(FATAL) <<
"Not implmented!";
291 #elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*) 294 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
295 return t ? CblasTrans : CblasNoTrans;
300 bool transa,
bool transb,
301 int m,
int n,
int k,
float alpha,
302 const float *A,
int lda,
const float *B,
int ldb,
303 float beta,
float *C,
int ldc) {
304 cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
305 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
308 bool transa,
bool transb,
309 int m,
int n,
int k,
float alpha,
310 const float *A,
int lda,
const float *B,
int ldb,
311 float beta,
float *C,
int ldc,
int batch_count,
313 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) 315 const int GROUP_SIZE = 1;
316 MKL_INT p_m[GROUP_SIZE] = {m};
317 MKL_INT p_n[GROUP_SIZE] = {n};
318 MKL_INT p_k[GROUP_SIZE] = {k};
319 MKL_INT p_lda[GROUP_SIZE] = {lda};
320 MKL_INT p_ldb[GROUP_SIZE] = {ldb};
321 MKL_INT p_ldc[GROUP_SIZE] = {ldc};
323 float p_alpha[GROUP_SIZE] = {alpha};
324 float p_beta[GROUP_SIZE] = {beta};
326 CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
327 CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
329 MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
330 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
331 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
333 std::vector<const float*> pp_A(batch_count,
nullptr);
334 std::vector<const float*> pp_B(batch_count,
nullptr);
335 std::vector<float*> pp_C(batch_count,
nullptr);
341 for (
int i = 0; i < batch_count; i++) {
342 pp_A[i] = A + i * m_k;
343 pp_B[i] = B + i * k_n;
344 pp_C[i] = C + i * m_n;
347 cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
348 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
349 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
351 for (
int i = 0; i < batch_count; ++i) {
352 gemm(stream, transa, transb, m, n, k, alpha,
353 A + i * m * k, lda, B + i * k * n, ldb,
354 beta, C + i * m * n, ldc);
359 bool trans,
int m,
int n,
360 float alpha,
const float *A,
int lda,
361 const float *X,
int incX,
362 float beta,
float *Y,
int incY) {
363 cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
364 A, lda, X, incX, beta, Y, incY);
367 bool trans,
int m,
int n,
368 float alpha,
const float *A,
int lda,
369 const float *X,
int incX,
370 float beta,
float *Y,
int incY,
int batch_count) {
371 for (
int i = 0; i < batch_count; ++i) {
372 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
373 X + i * (trans ? m : n) * incX, incX,
374 beta, Y + i * (trans ? n : m) * incY, incY);
378 int m,
int n,
float alpha,
379 const float *X,
int incX,
380 const float *Y,
int incY,
float *A,
int lda) {
381 cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
384 int m,
int n,
float alpha,
385 const float *X,
int incX,
386 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
387 for (
int i = 0; i < batch_count; ++i) {
388 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
389 A + i * lda * n, lda);
394 const float* X,
int incX,
395 const float* Y,
int incY,
397 *ret = cblas_sdot(n, X, incX, Y, incY);
403 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
404 return t ? CblasTrans : CblasNoTrans;
409 bool transa,
bool transb,
410 int m,
int n,
int k,
double alpha,
411 const double *A,
int lda,
const double *B,
int ldb,
412 double beta,
double *C,
int ldc) {
413 cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
414 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
417 bool transa,
bool transb,
418 int m,
int n,
int k,
double alpha,
419 const double *A,
int lda,
const double *B,
int ldb,
420 double beta,
double *C,
int ldc,
int batch_count,
421 double **workspace) {
422 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) 424 const int GROUP_SIZE = 1;
425 MKL_INT p_m[GROUP_SIZE] = {m};
426 MKL_INT p_n[GROUP_SIZE] = {n};
427 MKL_INT p_k[GROUP_SIZE] = {k};
428 MKL_INT p_lda[GROUP_SIZE] = {lda};
429 MKL_INT p_ldb[GROUP_SIZE] = {ldb};
430 MKL_INT p_ldc[GROUP_SIZE] = {ldc};
432 double p_alpha[GROUP_SIZE] = {alpha};
433 double p_beta[GROUP_SIZE] = {beta};
435 CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
436 CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
438 MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
439 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
440 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
442 std::vector<const double*> pp_A(batch_count,
nullptr);
443 std::vector<const double*> pp_B(batch_count,
nullptr);
444 std::vector<double*> pp_C(batch_count,
nullptr);
450 for (
int i = 0; i < batch_count; i++) {
451 pp_A[i] = A + i * m_k;
452 pp_B[i] = B + i * k_n;
453 pp_C[i] = C + i * m_n;
456 cblas_dgemm_batch(CblasColMajor, p_transa, p_transb,
457 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
458 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
460 for (
int i = 0; i < batch_count; ++i) {
461 gemm(stream, transa, transb, m, n, k, alpha,
462 A + i * m * k, lda, B + i * k * n, ldb,
463 beta, C + i * m * n, ldc);
468 bool trans,
int m,
int n,
double alpha,
469 const double *A,
int lda,
470 const double *X,
int incX,
471 double beta,
double *Y,
int incY) {
472 cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
473 A, lda, X, incX, beta, Y, incY);
476 bool trans,
int m,
int n,
477 double alpha,
const double *A,
int lda,
478 const double *X,
int incX,
479 double beta,
double *Y,
int incY,
int batch_count) {
480 for (
int i = 0; i < batch_count; ++i) {
481 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
482 X + i * (trans ? m : n) * incX, incX,
483 beta, Y + i * (trans ? n : m) * incY, incY);
487 int m,
int n,
double alpha,
488 const double *X,
int incX,
489 const double *Y,
int incY,
double *A,
int lda) {
490 cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
493 int m,
int n,
double alpha,
494 const double *X,
int incX,
495 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
496 for (
int i = 0; i < batch_count; ++i) {
497 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
498 A + i * lda * n, lda);
503 const double* X,
int incX,
504 const double* Y,
int incY,
506 *ret = cblas_ddot(n, X, incX, Y, incY);
509 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE 515 inline static cublasOperation_t
GetT(
bool t) {
516 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
521 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas set stream fail";
524 bool transa,
bool transb,
525 int m,
int n,
int k, half::half_t alpha,
526 const half::half_t *A,
int lda,
527 const half::half_t *B,
int ldb, half::half_t beta,
528 half::half_t *C,
int ldc) {
529 #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 531 float alpha_f = float(alpha);
532 float beta_f = float(beta);
533 #if CUDA_VERSION >= 8000 535 GetT(transa), GetT(transb), m, n, k, &alpha_f,
536 A, CUDA_R_16F, lda, B, CUDA_R_16F,
537 ldb, &beta_f, C, CUDA_R_16F, ldc);
538 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
541 GetT(transa), GetT(transb), m, n, k, &alpha_f,
542 A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
543 ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
544 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
545 #endif // CUDA_VERSION >= 8000 547 LOG(FATAL) <<
"Require CUDA version >= 7.5!";
548 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050 551 bool transa,
bool transb,
552 int m,
int n,
int k, half::half_t alpha,
553 const half::half_t *A,
int lda,
const half::half_t *B,
int ldb,
554 half::half_t beta, half::half_t *C,
int ldc,
int batch_count,
555 half::half_t **workspace) {
556 #if defined(__CUDACC__) && CUDA_VERSION >= 9000 557 int major = stream->
prop.major;
558 int minor = stream->
prop.minor;
560 if ((major > 5) || (major == 5 && minor >= 3)) {
561 const __half* A_h =
reinterpret_cast<const __half*
>(A);
562 const __half* B_h =
reinterpret_cast<const __half*
>(B);
563 __half* alpha_h =
reinterpret_cast<__half*
>(&alpha);
564 __half* beta_h =
reinterpret_cast<__half*
>(&beta);
565 __half* C_h =
reinterpret_cast<__half*
>(C);
567 GetT(transa), GetT(transb), m, n, k, alpha_h,
570 beta_h, C_h, ldc, m * n,
572 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: HgemmStridedBatched fail";
576 for (
int i = 0; i < batch_count; ++i) {
577 gemm(stream, transa, transb, m, n, k, alpha,
578 A + i * m * k, lda, B + i * k * n, ldb,
579 beta, C + i * m * n, ldc);
583 bool trans,
int m,
int n, half::half_t alpha,
584 const half::half_t *A,
int lda,
585 const half::half_t *X,
int incX, half::half_t beta,
586 half::half_t *Y,
int incY) {
587 LOG(FATAL) <<
"Not implmented!";
590 bool trans,
int m,
int n,
591 half::half_t alpha,
const half::half_t *A,
int lda,
592 const half::half_t *X,
int incX,
593 half::half_t beta, half::half_t *Y,
int incY,
int batch_count) {
594 LOG(FATAL) <<
"Not implmented!";
597 int m,
int n, half::half_t alpha,
598 const half::half_t *X,
int incX,
599 const half::half_t *Y,
int incY, half::half_t *A,
int lda) {
600 LOG(FATAL) <<
"Not implmented!";
603 int m,
int n, half::half_t alpha,
604 const half::half_t *X,
int incX,
const half::half_t *Y,
int incY,
605 half::half_t *A,
int lda,
int batch_count) {
606 LOG(FATAL) <<
"Not implmented!";
610 const half::half_t* X,
int incX,
611 const half::half_t* Y,
int incY,
613 LOG(FATAL) <<
"Not implmented!";
619 inline static cublasOperation_t
GetT(
bool t) {
620 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
625 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
628 bool transa,
bool transb,
629 int m,
int n,
int k,
float alpha,
630 const float *A,
int lda,
631 const float *B,
int ldb,
float beta,
634 GetT(transa), GetT(transb), m, n, k, &alpha,
635 A, lda, B, ldb, &beta, C, ldc);
636 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemm fail";
639 bool transa,
bool transb,
640 int m,
int n,
int k,
float alpha,
641 const float *A,
int lda,
const float *B,
int ldb,
642 float beta,
float *C,
int ldc,
int batch_count,
644 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 646 bool alloc_workspace =
false;
647 if (workspace == NULL) {
650 cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count *
sizeof(
float*));
651 alloc_workspace =
true;
653 GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
655 const_cast<float*>(B), batch_count, k * n, stream);
656 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
658 GetT(transa), GetT(transb), m, n, k, &alpha,
659 (
const float**)workspace, lda,
660 (
const float**)(workspace + batch_count), ldb,
661 &beta, workspace + 2 * batch_count, ldc, batch_count);
662 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmBatched fail";
663 if (alloc_workspace) {
666 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000 668 GetT(transa), GetT(transb), m, n, k, &alpha,
671 &beta, C, ldc, m * n,
673 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmStridedBatched fail";
675 for (
int i = 0; i < batch_count; ++i) {
676 gemm(stream, transa, transb, m, n, k, alpha,
677 A + i * m * k, lda, B + i * k * n, ldb,
678 beta, C + i * m * n, ldc);
680 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 683 bool trans,
int m,
int n,
float alpha,
684 const float *A,
int lda,
685 const float *X,
int incX,
float beta,
686 float *Y,
int incY) {
688 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
689 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemv fail";
692 bool trans,
int m,
int n,
693 float alpha,
const float *A,
int lda,
694 const float *X,
int incX,
695 float beta,
float *Y,
int incY,
int batch_count) {
696 for (
int i = 0; i < batch_count; ++i) {
697 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
698 X + i * (trans ? m : n) * incX, incX,
699 beta, Y + i * (trans ? n : m) * incY, incY);
703 int m,
int n,
float alpha,
704 const float *X,
int incX,
705 const float *Y,
int incY,
float *A,
int lda) {
707 m, n, &alpha, X, incX, Y, incY, A, lda);
708 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sger fail";
711 int m,
int n,
float alpha,
712 const float *X,
int incX,
713 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
714 for (
int i = 0; i < batch_count; ++i) {
715 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
716 A + i * lda * n, lda);
721 const float* X,
int incX,
722 const float* Y,
int incY,
725 CUBLAS_POINTER_MODE_DEVICE);
727 n, X, incX, Y, incY, ret);
728 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
730 CUBLAS_POINTER_MODE_HOST);
736 inline static cublasOperation_t
GetT(
bool t) {
737 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
742 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
745 bool transa,
bool transb,
746 int m,
int n,
int k,
double alpha,
747 const double *A,
int lda,
748 const double *B,
int ldb,
749 double beta,
double *C,
int ldc) {
751 GetT(transa), GetT(transb), m, n, k, &alpha,
752 A, lda, B, ldb, &beta, C, ldc);
753 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemm fail";
756 bool transa,
bool transb,
757 int m,
int n,
int k,
double alpha,
758 const double *A,
int lda,
const double *B,
int ldb,
759 double beta,
double *C,
int ldc,
int batch_count,
760 double **workspace) {
761 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 763 bool alloc_workspace =
false;
764 if (workspace == NULL) {
767 cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count *
sizeof(
double*));
768 alloc_workspace =
true;
770 GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
772 const_cast<double*>(B), batch_count, k * n, stream);
773 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
775 GetT(transa), GetT(transb), m, n, k, &alpha,
776 (
const double**)workspace, lda,
777 (
const double**)(workspace + batch_count), ldb,
778 &beta, workspace + 2 * batch_count, ldc, batch_count);
779 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmBatched fail";
780 if (alloc_workspace) {
783 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000 785 GetT(transa), GetT(transb), m, n, k, &alpha,
788 &beta, C, ldc, m * n,
790 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmStridedBatched fail";
792 for (
int i = 0; i < batch_count; ++i) {
793 gemm(stream, transa, transb, m, n, k, alpha,
794 A + i * m * k, lda, B + i * k * n, ldb,
795 beta, C + i * m * n, ldc);
797 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 800 bool trans,
int m,
int n,
double alpha,
801 const double *A,
int lda,
802 const double *X,
int incX,
803 double beta,
double *Y,
int incY) {
805 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
806 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemv fail";
809 bool trans,
int m,
int n,
810 double alpha,
const double *A,
int lda,
811 const double *X,
int incX,
812 double beta,
double *Y,
int incY,
int batch_count) {
813 for (
int i = 0; i < batch_count; ++i) {
814 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
815 X + i * (trans ? m : n) * incX, incX,
816 beta, Y + i * (trans ? n : m) * incY, incY);
820 int m,
int n,
double alpha,
821 const double *X,
int incX,
822 const double *Y,
int incY,
double *A,
int lda) {
824 m, n, &alpha, X, incX, Y, incY, A, lda);
825 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dger fail";
828 int m,
int n,
double alpha,
829 const double *X,
int incX,
830 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
831 for (
int i = 0; i < batch_count; ++i) {
832 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
833 A + i * lda * n, lda);
838 const double* X,
int incX,
839 const double* Y,
int incY,
842 CUBLAS_POINTER_MODE_DEVICE);
844 n, X, incX, Y, incY, ret);
845 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
847 CUBLAS_POINTER_MODE_HOST);
850 #endif // MSHADOW_USE_CUDA 853 return transpose ?
Shape2(shape[1], shape[0]) : shape;
856 template<
typename SV,
typename xpu,
857 bool transpose_left,
bool transpose_right,
typename DType>
858 struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
864 #if MSHADOW_STAND_ALONE 866 if (!transpose_left && !transpose_right) {
868 }
else if (!transpose_left && transpose_right) {
870 }
else if (transpose_left && !transpose_right) {
880 CHECK(dst.
size(0) == sleft[0] && dst.
size(1) == sright[1] && sleft[1] == sright[0])
881 <<
"dot-gemm: matrix shape mismatch";
885 transpose_right , transpose_left,
886 transpose_right ? rhs.
size(0) : rhs.
size(1),
887 transpose_left ? lhs.
size(1) : lhs.
size(0),
888 transpose_right ? rhs.
size(1) : rhs.
size(0),
889 DType(scale * SV::AlphaBLAS()),
892 DType(SV::BetaBLAS()),
896 template<
typename SV,
typename xpu,
bool transpose_right,
typename DType>
897 struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
907 CHECK(dst.
size(0) == sright[1] && lhs.
size(0) == sright[0])
908 <<
"dot-gemv: matrix shape mismatch" 909 <<
"dst: " << dst.
shape_ <<
"\n" 910 <<
"lhs: " << lhs.
shape_ <<
"\n" 911 <<
"rhs: " << sright <<
"\n";
915 rhs.
size(1), rhs.
size(0), scale * SV::AlphaBLAS(),
917 lhs.
dptr_, 1, SV::BetaBLAS(),
921 template<
typename SV,
typename xpu,
typename DType>
922 struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
932 <<
"dot-ger: matrix shape mismatch" 933 <<
"dst: " << dst.
shape_ <<
"\n" 934 <<
"lhs: " << lhs.
shape_ <<
"\n" 936 if (SV::BetaBLAS() == 0.0f) {
948 #endif // MSHADOW_DOT_ENGINE_INL_H_ static void ger(Stream< gpu > *stream, int m, int n, half::half_t alpha, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *A, int lda)
Definition: dot_engine-inl.h:596
static void batched_gemv(Stream< Device > *stream, bool trans, int m, int n, DType alpha, const DType *A, int lda, const DType *X, int incX, DType beta, DType *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:106
static void batched_gemm(Stream< Device > *stream, bool transa, bool transb, int m, int n, int k, DType alpha, const DType *A, int lda, const DType *B, int ldb, DType beta, DType *C, int ldc, int batch_count, DType **workspace)
Definition: dot_engine-inl.h:91
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:739
static void SetStream(Stream< cpu > *stream)
Definition: dot_engine-inl.h:406
ImplicitGEMMExp< LhsExp, RhsExp, DType > implicit_dot(const Exp< LhsExp, DType, e1 > &lhs, const Exp< RhsExp, DType, e2 > &rhs)
Definition: implicit_gemm.h:64
static void SetStream(Stream< cpu > *stream)
Definition: dot_engine-inl.h:297
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:622
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:691
DType * dptr_
pointer to the data
Definition: tensor.h:434
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *X, int incX, half::half_t beta, half::half_t *Y, int incY)
Definition: dot_engine-inl.h:582
Shape< 2 > GetShape(const Shape< 2 > &shape, bool transpose)
Definition: dot_engine-inl.h:852
static void gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
Definition: dot_engine-inl.h:408
static void gemv(Stream< cpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY)
Definition: dot_engine-inl.h:358
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY)
Definition: dot_engine-inl.h:799
static void batched_ger(Stream< cpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda, int batch_count)
Definition: dot_engine-inl.h:492
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc, int batch_count, double **workspace)
Definition: dot_engine-inl.h:755
static void gemv(Stream< cpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY)
Definition: dot_engine-inl.h:467
Definition: stream_gpu-inl.h:37
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:808
static void batched_ger(Stream< cpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda, int batch_count)
Definition: dot_engine-inl.h:383
support for implicit GEMM operation
static void batched_ger(Stream< gpu > *stream, int m, int n, half::half_t alpha, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *A, int lda, int batch_count)
Definition: dot_engine-inl.h:602
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:436
static void gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc)
Definition: dot_engine-inl.h:299
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:619
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *X, int incX, half::half_t beta, half::half_t *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:589
static void ger(Stream< gpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda)
Definition: dot_engine-inl.h:702
static void batched_ger(Stream< gpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda, int batch_count)
Definition: dot_engine-inl.h:827
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:240
static void batched_gemv(Stream< cpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:475
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int batch_count, float **workspace)
Definition: dot_engine-inl.h:638
Definition: dot_engine-inl.h:70
static void batched_ger(Stream< gpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda, int batch_count)
Definition: dot_engine-inl.h:710
cudaDeviceProp prop
cudaDeviceProp
Definition: stream_gpu-inl.h:62
device name GPU
Definition: tensor.h:46
static void batched_ger(Stream< Device > *stream, int m, int n, DType alpha, const DType *X, int incX, const DType *Y, int incY, DType *A, int lda, int batch_count)
Definition: dot_engine-inl.h:119
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:736
static void Eval(Tensor< xpu, 2, DType > *p_dst, const Tensor< xpu, 2, DType > &lhs, const Tensor< xpu, 2, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:859
const TransposeExp< Tensor< Device, dimension, DType >, DType > T(void) const
transpose of a matrix
Definition: expression.h:154
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimensions, collapsing the higher dimensions together
Definition: tensor.h:519
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *B, int ldb, half::half_t beta, half::half_t *C, int ldc)
Definition: dot_engine-inl.h:523
MSHADOW_XINLINE index_t size(int idx) const
returns the size of the i-th dimension, counting from the highest dimension
Definition: tensor.h:505
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *B, int ldb, half::half_t beta, half::half_t *C, int ldc, int batch_count, half::half_t **workspace)
Definition: dot_engine-inl.h:550
static void batched_gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc, int batch_count, double **workspace)
Definition: dot_engine-inl.h:416
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
Definition: dot_engine-inl.h:744
static CBLAS_TRANSPOSE GetT(bool t)
Definition: dot_engine-inl.h:403
static void dot(Stream< Device > *stream, int n, const DType *X, int incX, const DType *Y, int incY, DType *ret)
Definition: dot_engine-inl.h:125
static void batched_gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int batch_count, float **workspace)
Definition: dot_engine-inl.h:307
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:515
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:216
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:43
static void ger(Stream< cpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda)
Definition: dot_engine-inl.h:486
static void ger(Stream< cpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda)
Definition: dot_engine-inl.h:377
void GetBatchedView(DType **dst, DType *src, int num, int stride, Stream< Device > *stream)
CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride.
static void dot(Stream< gpu > *stream, int n, const float *X, int incX, const float *Y, int incY, float *ret)
Definition: dot_engine-inl.h:719
static void Eval(Tensor< xpu, 2, DType > *p_dst, const Tensor< xpu, 1, DType > &lhs, const Tensor< xpu, 1, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:923
static void dot(Stream< cpu > *stream, int n, const float *X, int incX, const float *Y, int incY, float *ret)
Definition: dot_engine-inl.h:392
static void ger(Stream< Device > *stream, int m, int n, DType alpha, const DType *X, int incX, const DType *Y, int incY, DType *A, int lda)
Definition: dot_engine-inl.h:113
static void dot(Stream< gpu > *stream, int n, const double *X, int incX, const double *Y, int incY, double *ret)
Definition: dot_engine-inl.h:836
static void batched_gemv(Stream< cpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:366
static void Eval(Tensor< xpu, 1, DType > *p_dst, const Tensor< xpu, 1, DType > &lhs, const Tensor< xpu, 2, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:898
static void dot(Stream< cpu > *stream, int n, const double *X, int incX, const double *Y, int incY, double *ret)
Definition: dot_engine-inl.h:501
static void ger(Stream< gpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda)
Definition: dot_engine-inl.h:819
overloaded + operator between half_t and bf16_t
Definition: base.h:334
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY)
Definition: dot_engine-inl.h:682
static void gemm(Stream< Device > *stream, bool transa, bool transb, int m, int n, int k, DType alpha, const DType *A, int lda, const DType *B, int ldb, DType beta, DType *C, int ldc)
Definition: dot_engine-inl.h:84
TransposeExExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > transpose(const Exp< SrcExp, DType, etype > &src, Shape< ExpInfo< SrcExp >::kDim > axes)
an expression that transposes a tensor by permuting its axes according to axes
Definition: transpose.h:76
index_t stride_
stores the stride information in the x dimension; this is used to deal with pitch allocation in gpu or ss...
Definition: tensor.h:441
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:518
Definition: dot_engine-inl.h:78
static void dot(Stream< gpu > *stream, int n, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *ret)
Definition: dot_engine-inl.h:608
general tensor
Definition: tensor.h:420
static void gemv(Stream< Device > *stream, bool trans, int m, int n, DType alpha, const DType *A, int lda, const DType *X, int incX, DType beta, DType *Y, int incY)
Definition: dot_engine-inl.h:99
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc)
Definition: dot_engine-inl.h:627
static void SetStream(Stream< Device > *stream)
Definition: dot_engine-inl.h:82
static bool GetT(bool t)
Definition: dot_engine-inl.h:79
Stream< Device > * stream_
stream where the computation lies; a stream is a device dependency concept where each computation ...
Definition: tensor.h:446
void GetBatchedView(DType **dst, DType *src, int num, int stride, Stream< cpu > *stream)
Definition: dot_engine-inl.h:49
computation stream structure, used for asynchronous computations
Definition: tensor.h:383
static CBLAS_TRANSPOSE GetT(bool t)
Definition: dot_engine-inl.h:294