cuBLAS �?GPU 线性代数库
cuBLAS �?NVIDIA 官方�?BLAS(Basic Linear Algebra Subprograms)GPU 实现,提供高度优化的矩阵和向量运算,是深度学习框架底层计算的核心引擎�?
什么是 BLAS
BLAS 将线性代数运算分为三个层次:
| Level | 操作 | 复杂�? | 示例 |
|---|---|---|---|
| Level 1 | 向量-向量 | O(n) | AXPY: y = αx + y |
| Level 2 | 矩阵-向量 | O(n²) | GEMV: y = αAx + βy |
| Level 3 | 矩阵-矩阵 | O(n³) | GEMM: C = αAB + βC |
*GEMM(General Matrix Multiply�? 是深度学习中最核心的操作,全连接层、卷积(im2col后)、Attention 都归结为 GEMM�?
cuBLAS 初始�?
cpp
#include <cublas_v2.h>
cublasHandle_t handle;
cublasCreate(&handle);
// ... 使用 cuBLAS ...
cublasDestroy(handle);GEMM:矩阵乘�?
基本用法
cpp
// C = alpha * A * B + beta * C
// A: M×K, B: K×N, C: M×N
cublasStatus_t cublasSgemm(
cublasHandle_t handle,
cublasOperation_t transa, // CUBLAS_OP_N �?CUBLAS_OP_T
cublasOperation_t transb,
int m, int n, int k,
const float* alpha,
const float* A, int lda, // leading dimension of A
const float* B, int ldb,
const float* beta,
float* C, int ldc
);完整示例
cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>
void matmul(int M, int N, int K) {
float alpha = 1.0f, beta = 0.0f;
// 分配设备内存
float *d_A, *d_B, *d_C;
cudaMalloc(&d_A, M * K * sizeof(float));
cudaMalloc(&d_B, K * N * sizeof(float));
cudaMalloc(&d_C, M * N * sizeof(float));
cublasHandle_t handle;
cublasCreate(&handle);
// 注意:cuBLAS 使用列主序(Column-Major�?
// 对于行主序的 C/C++ 数组,需要交�?A �?B 的顺�?
// C^T = B^T * A^T �? C(row-major) = A * B
cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
N, M, K, // 注意:n, m, k(交换了 m �?n�?
&alpha,
d_B, N, // B 作为第一个矩�?
d_A, K, // A 作为第二个矩�?
&beta,
d_C, N);
cublasDestroy(handle);
cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);
}列主序陷�?
cuBLAS 内部使用 Fortran 风格的列主序存储。C/C++ 使用行主序。处理方式:利用 (AB)^T = B^T A^T,交�?A、B 参数顺序,并交换 m、n 参数�?
批量 GEMM(Batched GEMM�?
深度学习中常需要对一批小矩阵做乘法(�?Batch 维度�?Attention):
cpp
// 批量矩阵乘法:C[i] = alpha * A[i] * B[i] + beta * C[i]
cublasSgemmBatched(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
m, n, k,
&alpha,
(const float**)d_Aarray, lda, // 指针数组
(const float**)d_Barray, ldb,
&beta,
d_Carray, ldc,
batchCount
);
// 步长版本(矩阵连续存储时更高效)
cublasSgemmStridedBatched(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
m, n, k,
&alpha,
d_A, lda, strideA, // strideA = m * k
d_B, ldb, strideB,
&beta,
d_C, ldc, strideC,
batchCount
);混合精度 GEMM
Ampere 及以上架构支�?TF32/FP16/BF16 Tensor Core 加速:
cpp
// 使用 cublasGemmEx 指定计算精度
cublasGemmEx(
handle,
CUBLAS_OP_N, CUBLAS_OP_N,
m, n, k,
&alpha,
d_A, CUDA_R_16F, lda, // FP16 输入
d_B, CUDA_R_16F, ldb,
&beta,
d_C, CUDA_R_16F, ldc,
CUBLAS_COMPUTE_32F, // FP32 累加(精度更高)
CUBLAS_GEMM_DEFAULT_TENSOR_OP // 使用 Tensor Core
);
// FP8 (Hopper H100)
cublasGemmEx(
handle, CUBLAS_OP_N, CUBLAS_OP_N,
m, n, k, &alpha,
d_A, CUDA_R_8F_E4M3, lda,
d_B, CUDA_R_8F_E4M3, ldb,
&beta,
d_C, CUDA_R_16F, ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);精度选择指南
| 场景 | 推荐精度 | 速度提升 |
|---|---|---|
| 科学计算 | FP64 | 基准 |
| 深度学习训练 | BF16/TF32 | 10-20x |
| 深度学习推理 | FP16/INT8 | 20-40x |
| LLM 推理 | FP8 | 40-80x |
其他常用操作
cpp
// AXPY: y = alpha * x + y
cublasSaxpy(handle, N, &alpha, d_x, 1, d_y, 1);
// DOT: result = x · y
float result;
cublasSdot(handle, N, d_x, 1, d_y, 1, &result);
// NRML2: ||x||�?
float norm;
cublasSnrm2(handle, N, d_x, 1, &norm);
// SCAL: x = alpha * x
cublasSscal(handle, N, &alpha, d_x, 1);
// GEMV: y = alpha * A * x + beta * y
cublasSgemv(handle, CUBLAS_OP_N, m, n, &alpha, d_A, m, d_x, 1, &beta, d_y, 1);性能调优
使用 Tensor Core
cpp
// 启用 Tensor Core(默认已启用,但可显式设置)
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
// 禁用 Tensor Core(用于调试对比)
cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);工作空间
cuBLAS 某些算法需要额外工作空间:
cpp
// 查询所需工作空间大小
size_t workspaceSize;
cublasGetWorkspace(handle, &workspaceSize);
// 分配并设置工作空�?
void* workspace;
cudaMalloc(&workspace, workspaceSize);
cublasSetWorkspace(handle, workspace, workspaceSize);�?Stream 集成
cpp
cudaStream_t stream;
cudaStreamCreate(&stream);
// �?cuBLAS 操作绑定�?Stream
cublasSetStream(handle, stream);
// 异步执行
cublasSgemm(handle, ...); // �?stream 上异步执�?
cudaStreamSynchronize(stream);GEMM 性能基准(A100 80GB�?
| 精度 | 矩阵大小 | 实测 TFLOPS | 峰值利用率 |
|---|---|---|---|
| FP64 | 8192×8192 | ~9.7 | ~97% |
| FP32 | 8192×8192 | ~19.2 | ~98% |
| TF32 | 8192×8192 | ~150 | ~96% |
| FP16 | 8192×8192 | ~300 | ~96% |
| INT8 | 8192×8192 | ~600 | ~97% |