【プログラム解析】cublasSgemm行列乗算の詳細説明



Detailed Explanation Cublassgemm Matrix Multiplication



cublasStatus_t cublasSgemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *A, int lda, const float *B, int ldb, const float *beta, float *C, int ldc)

CublasSgemmの公式APIドキュメントリンク https://docs.nvidia.com/cuda/cublas/index.html



  • ドキュメントによると、cublasSgemmがC = alpha * op(A)* op(B)+ beta * Cの行列の乗算-加算演算を完了したことがわかります。
  • アルファとベータはスカラー、A BCは列優先のストレージマトリックスです
  • transaパラメーターがCUBLAS_OP_Nの場合、op(A)= A、CUBLAS_OP_Tの場合、op(A)= Aが転置されます。
  • transbのパラメーターがCUBLAS_OP_Nの場合、op(B)= B、CUBLAS_OP_Tの場合、op(B)= B転置

APIの行列パラメーターもAB Cで表されるため、次の例の行列A Bと混同しないように、cublasSgemmのパラメーターを次のように調整します。

  • Aは乗法左行列と呼ばれます
  • Bは乗法右行列と呼ばれます
  • Cは結果行列と呼ばれます

したがって、alpha = 1およびbeta = 0の場合、cublasSgemmは計算を完了します:結果行列= op(左行列を乗算)* op(右行列を乗算)



C = AxBを解く

それらの中で(AはM行K列BはK行N列なので、CはM行N列です)

cublasSgemmtransaおよびtransbパラメーターを使用しないでください

C / C ++プログラムのAとBの入力は行ごとに格納されるため、この場合、cublasは実際にAとBの転置行列ATとBTを読み取ります。



線形代数の規則によれば、CT =(A x B)T = BT x ATであるため、cublasSgemmAPIのいくつかのパラメーターは次のように設定されます。

  • 乗算の左行列はBT =パラメーターがBに設定され、乗算の右行列はAT =パラメーターがAに設定されます
  • 結果マトリックスの行数は次のとおりですCTの行数=パラメーターはNに設定されます
  • 結果マトリックスの列数は次のとおりですCTの列数=パラメーターはMに設定されます
  • 乗算の左行列列と乗算の右行列行=パラメーターをKに設定
  • 乗算左行列BTの主次元(つまり、数行)=パラメーターはNに設定されます
  • 乗算右行列ATの主次元(つまり、複数の行があります)=パラメーターはKに設定されます
  • 結果の行列はパラメーターCに格納され、その主な次元(つまり、複数の行があります)=パラメーターはNに設定されます

cublasSgemm(handle、CUBLAS_OP_N、CUBLAS_OP_N、N、M、K、&alpha、d_b、N、d_a、K、&beta、d_c、N)

上記のパラメータに従ってcublasSgemmAPIを呼び出します(行列Aは行ごとにポインタd_aに格納され、行列Bは行ごとにポインタd_bに格納され、行列Cの格納スペースポインタd_cに格納されます)最後に、結果行列のストレージスペースd_cは次のとおりです。C= AxBの結果、cublasSgemm計算プロセス全体を次の図に示します。

cublasSgemm行列の乗算を解くプロセス

サンプルプログラム

#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define M 2 #define N 4 #define K 3 void printMatrix(float (*matrix)[N], int row, int col, bool reverse) { if (reverse) { int temp temp = row row = col col = temp } std::cout << ' address of matrix: ' << matrix << ' address of (matrix+1):' << (matrix + 1) << ' sizeof(matrix): ' << sizeof(matrix) << ' sizeof(*matrix): ' << sizeof(*matrix) << std::endl for(int i=0i

運転結果

[38 44 50 56]

[83 98113128]

サンプルプログラムのcublasSgemm計算ソリューションプロセス