【プログラム解析】cublasSgemm行列乗算の詳細説明
Detailed Explanation Cublassgemm Matrix Multiplication
cublasStatus_t cublasSgemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float *alpha, const float *A, int lda, const float *B, int ldb, const float *beta, float *C, int ldc)
CublasSgemmの公式APIドキュメントリンク https://docs.nvidia.com/cuda/cublas/index.html
- ドキュメントによると、cublasSgemmがC = alpha * op(A)* op(B)+ beta * Cの行列の乗算-加算演算を完了したことがわかります。
- アルファとベータはスカラー、A BCは列優先のストレージマトリックスです
- transaパラメーターがCUBLAS_OP_Nの場合、op(A)= A、CUBLAS_OP_Tの場合、op(A)= Aが転置されます。
- transbのパラメーターがCUBLAS_OP_Nの場合、op(B)= B、CUBLAS_OP_Tの場合、op(B)= B転置
APIの行列パラメーターもAB Cで表されるため、次の例の行列A Bと混同しないように、cublasSgemmのパラメーターを次のように調整します。
- Aは乗法左行列と呼ばれます
- Bは乗法右行列と呼ばれます
- Cは結果行列と呼ばれます
したがって、alpha = 1およびbeta = 0の場合、cublasSgemmは計算を完了します:結果行列= op(左行列を乗算)* op(右行列を乗算)
C = AxBを解く
それらの中で(AはM行K列BはK行N列なので、CはM行N列です)
cublasSgemmtransaおよびtransbパラメーターを使用しないでください
C / C ++プログラムのAとBの入力は行ごとに格納されるため、この場合、cublasは実際にAとBの転置行列ATとBTを読み取ります。
線形代数の規則によれば、CT =(A x B)T = BT x ATであるため、cublasSgemmAPIのいくつかのパラメーターは次のように設定されます。
- 乗算の左行列はBT =パラメーターがBに設定され、乗算の右行列はAT =パラメーターがAに設定されます
- 結果マトリックスの行数は次のとおりですCTの行数=パラメーターはNに設定されます
- 結果マトリックスの列数は次のとおりですCTの列数=パラメーターはMに設定されます
- 乗算の左行列列と乗算の右行列行=パラメーターをKに設定
- 乗算左行列BTの主次元(つまり、数行)=パラメーターはNに設定されます
- 乗算右行列ATの主次元(つまり、複数の行があります)=パラメーターはKに設定されます
- 結果の行列はパラメーターCに格納され、その主な次元(つまり、複数の行があります)=パラメーターはNに設定されます
cublasSgemm(handle、CUBLAS_OP_N、CUBLAS_OP_N、N、M、K、&alpha、d_b、N、d_a、K、&beta、d_c、N)
上記のパラメータに従ってcublasSgemmAPIを呼び出します(行列Aは行ごとにポインタd_aに格納され、行列Bは行ごとにポインタd_bに格納され、行列Cの格納スペースポインタd_cに格納されます)最後に、結果行列のストレージスペースd_cは次のとおりです。C= AxBの結果、cublasSgemm計算プロセス全体を次の図に示します。
cublasSgemm行列の乗算を解くプロセス
サンプルプログラム
#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define M 2 #define N 4 #define K 3 void printMatrix(float (*matrix)[N], int row, int col, bool reverse) { if (reverse) { int temp temp = row row = col col = temp } std::cout << ' address of matrix: ' << matrix << ' address of (matrix+1):' << (matrix + 1) << ' sizeof(matrix): ' << sizeof(matrix) << ' sizeof(*matrix): ' << sizeof(*matrix) << std::endl for(int i=0i運転結果
[38 44 50 56]
[83 98113128]
サンプルプログラムのcublasSgemm計算ソリューションプロセス