CUDA operators: the softmax operator and its optimization
Table of contents
- 1. Preface
- 2. CUDA implementation
- 2.1 Implementation steps
1. Preface
In a binary classification problem we usually want the output as a probability and apply a threshold to make the decision. For the multi-class case, the softmax function converts several linear outputs into probabilities that sum to 1.
The softmax function is computed as:
$\mathrm{Softmax}(x_i) = \dfrac{\exp(x_i - \max)}{\sum_{j=0}^{n} \exp(x_j - \max)}$
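For example, applying the formula to the logits $(1, 2, 3)$, where $\max = 3$, the subtraction and normalization give:
$\mathrm{Softmax}(1, 2, 3) = \dfrac{(e^{-2},\ e^{-1},\ e^{0})}{e^{-2} + e^{-1} + e^{0}} \approx (0.090,\ 0.245,\ 0.665)$
which indeed sums to 1.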
It is described in detail in the PyTorch softmax API documentation; click here to view the source code of the PyTorch implementation.
2. CUDA implementation
2.1 Implementation steps
- Compute the maximum of each row
- Compute exp_sum, the sum of exp(x - max) over the row
- Compute the probability of each element (a minimal reference sketch of these three passes follows)
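As a point of reference, here is a minimal CPU sketch of these three passes over a contiguous (rows x dim) matrix; the function name softmax_ref and its layout are assumptions for illustration, not code from the original post. The CUDA kernel below parallelizes exactly these passes.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference three-pass softmax over each row of a (rows x dim) matrix.
void softmax_ref(const std::vector<float> &src, std::vector<float> &dst,
                 int rows, int dim) {
    for (int r = 0; r < rows; ++r) {
        const float *in = src.data() + r * dim;
        float *out = dst.data() + r * dim;
        // Pass 1: row maximum (for numerical stability).
        float max_val = *std::max_element(in, in + dim);
        // Pass 2: sum of exp(x - max).
        float exp_sum = 0.0f;
        for (int i = 0; i < dim; ++i) exp_sum += std::exp(in[i] - max_val);
        // Pass 3: normalized probabilities.
        for (int i = 0; i < dim; ++i) out[i] = std::exp(in[i] - max_val) / exp_sum;
    }
}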
CUDA implementation strategy: the original approach lets one SM handle an entire row, but when the channel dimension is small the SM has spare resources that can also be used to process the batch dimension. The thread block is therefore given a two-dimensional layout: the x dimension handles the softmax reduction, while the y dimension, together with the grid, handles the batch.
The concrete implementation is as follows:
#include <cstdint>
#include <cmath>
#include <cuda_runtime.h>

// Reduction operators; oob_val is the identity element used to pad
// out-of-bounds lanes during the block-level reduction.
template <typename E> struct Add {
    __device__ __forceinline__ E operator()(E left, E right) { return left + right; }
    constexpr static E oob_val = 0;
};

template <typename E> struct Max {
    __device__ __forceinline__ E operator()(E left, E right) { return max(left, right); }
    constexpr static E oob_val = -INFINITY;
};

// Butterfly (xor-shuffle) reduction within one warp.
template <typename E, typename OP>
__device__ __forceinline__ E warp_reduce(E value, OP op) {
    E tmp = __shfl_xor_sync(0xffffffff, value, 1, 2);
    tmp = op(tmp, value);
    E tmp2 = __shfl_xor_sync(0xffffffff, tmp, 2, 4);
    tmp2 = op(tmp2, tmp);
    tmp = __shfl_xor_sync(0xffffffff, tmp2, 4, 8);
    tmp = op(tmp, tmp2);
    tmp2 = __shfl_xor_sync(0xffffffff, tmp, 8, 16);
    tmp2 = op(tmp2, tmp);
    tmp = __shfl_xor_sync(0xffffffff, tmp2, 16, 32);
    return op(tmp, tmp2);
}

// Two-level reduction along the x dimension of the block: warp reduce,
// stage per-warp partials in shared memory, then reduce the partials.
// Assumes blockDim.x is a multiple of 32.
template <typename E, typename OP>
__device__ __forceinline__ E blockReduce(E *sdata, E value, OP op) {
    E rel_val = warp_reduce(value, op);
    if (threadIdx.x % 32 == 0) {
        sdata[threadIdx.x / 32] = rel_val;
    }
    __syncthreads();
    rel_val = warp_reduce(
        threadIdx.x % 32 < blockDim.x / 32 ? sdata[threadIdx.x % 32] : OP::oob_val,
        op);
    return rel_val;
}

// Round x up to the next multiple of base.
__device__ __host__ __forceinline__ int upAlign(int x, int base) {
    return (x + base - 1) / base * base;
}

// Element-wise float4 helpers.
__device__ __forceinline__ float4 operator-(const float4 &hr, const float4 &hl) {
    return make_float4(hr.x - hl.x, hr.y - hl.y, hr.z - hl.z, hr.w - hl.w);
}

__device__ __forceinline__ float4 &operator+=(float4 &hr, const float4 &hl) {
    hr = make_float4(hr.x + hl.x, hr.y + hl.y, hr.z + hl.z, hr.w + hl.w);
    return hr;
}

__device__ __forceinline__ float4 operator/(const float4 &hr, const float4 &hl) {
    return make_float4(hr.x / hl.x, hr.y / hl.y, hr.z / hl.z, hr.w / hl.w);
}

__device__ __forceinline__ float4 exp(const float4 &val) {
    return make_float4(exp(val.x), exp(val.y), exp(val.z), exp(val.w));
}

__device__ __forceinline__ float4 max(const float4 &hr, const float4 &hl) {
    return make_float4(max(hr.x, hl.x), max(hr.y, hl.y), max(hr.z, hl.z), max(hr.w, hl.w));
}

__global__ void kernel(float *dst, float *src, uint32_t outer_size,
                       uint32_t inner_size, uint32_t dim_size) {
    extern __shared__ char smem[];
    auto sdata = reinterpret_cast<float *>(smem);
    const uint32_t outer_stride = inner_size * dim_size;
    for (uint32_t outer_index = blockIdx.x; outer_index < outer_size;
         outer_index += gridDim.x) {
        const uint32_t outer_offset = outer_index * outer_stride;
        for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y;
             inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
            const uint32_t data_offset = outer_offset + inner_index;
            // Each row handled by the block (threadIdx.y) gets its own
            // 32-float staging area for warp partials.
            float *shm = sdata + threadIdx.y * 32;
            float *isrc = src + data_offset;
            float *idst = dst + data_offset;

            // Pass 1: row maximum.
            float4 max_val_v4 = {-INFINITY, -INFINITY, -INFINITY, -INFINITY};
            for (int w = threadIdx.x * 4; w < upAlign(dim_size, 4); w += blockDim.x * 4) {
                float4 sv;
                // Tail that is not a multiple of 4: pad with -INFINITY.
                if (dim_size - w < 4) {
                    for (int i = 0; i < 4; i++) {
                        reinterpret_cast<float *>(&sv)[i] =
                            i < dim_size - w ? isrc[w + i] : -INFINITY;
                    }
                } else {
                    sv = make_float4(isrc[w], isrc[w + 1], isrc[w + 2], isrc[w + 3]);
                    // sv = *((float4 *)(isrc + w));
                }
                max_val_v4 = max(sv, max_val_v4);
            }
            float rel_max_val = max(max(max_val_v4.x, max_val_v4.y),
                                    max(max_val_v4.z, max_val_v4.w));
            rel_max_val = blockReduce(shm, rel_max_val, Max<float>{});
            max_val_v4 = {rel_max_val, rel_max_val, rel_max_val, rel_max_val};

            // Pass 2: sum of exp(x - max).
            float4 exp_sum_v4 = {0.0f, 0.0f, 0.0f, 0.0f};
            for (int w = threadIdx.x * 4; w < upAlign(dim_size, 4); w += blockDim.x * 4) {
                // Tail that is not a multiple of 4: accumulate element by element.
                if (dim_size - w < 4) {
                    for (int i = 0; i < dim_size - w; i++) {
                        reinterpret_cast<float *>(&exp_sum_v4)[i] +=
                            exp(isrc[w + i] - rel_max_val);
                    }
                } else {
                    float4 sv = make_float4(isrc[w], isrc[w + 1], isrc[w + 2], isrc[w + 3]);
                    // sv = *((float4 *)(isrc + w));
                    exp_sum_v4 += exp(sv - max_val_v4);
                }
            }
            float rel_exp_sum = exp_sum_v4.x + exp_sum_v4.y + exp_sum_v4.z + exp_sum_v4.w;
            rel_exp_sum = blockReduce(shm, rel_exp_sum, Add<float>{});
            exp_sum_v4 = {rel_exp_sum, rel_exp_sum, rel_exp_sum, rel_exp_sum};

            // Pass 3: write the normalized probabilities for this row.
            for (int w = threadIdx.x * 4; w < upAlign(dim_size, 4); w += blockDim.x * 4) {
                // Tail that is not a multiple of 4: scalar stores.
                if (dim_size - w < 4) {
                    for (int i = 0; i < dim_size - w; i++) {
                        idst[w + i] = exp(isrc[w + i] - rel_max_val) / rel_exp_sum;
                    }
                } else {
                    float4 sv = make_float4(isrc[w], isrc[w + 1], isrc[w + 2], isrc[w + 3]);
                    // sv = *((float4 *)(isrc + w));
                    // 16-byte store; requires idst + w to be 16-byte aligned.
                    *(reinterpret_cast<float4 *>(idst + w)) =
                        exp(sv - max_val_v4) / exp_sum_v4;
                }
            }
        }
    }
}
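The post does not show the host-side launch. Under the two-dimensional block design described above, it could look roughly like the sketch below; the block shape (128, 4), the grid sizing, and the wrapper name launch_softmax are illustrative assumptions, not code from the original post. The dynamic shared memory must provide 32 floats of warp-partial storage per row handled by the block.

// Illustrative launch configuration (assumed, not from the original post).
void launch_softmax(float *dst, float *src, uint32_t outer_size,
                    uint32_t inner_size, uint32_t dim_size,
                    cudaStream_t stream = 0) {
    dim3 block(128, 4);  // x: softmax reduction, y: batch rows per block
    unsigned int grid_y = (inner_size + block.y - 1) / block.y;
    if (grid_y > 65535u) grid_y = 65535u;  // the kernel grid-strides over the rest
    dim3 grid(outer_size, grid_y);
    size_t smem = block.y * 32 * sizeof(float);  // per-row warp partials
    kernel<<<grid, block, smem, stream>>>(dst, src, outer_size, inner_size, dim_size);
}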
I originally wanted to use ld.global.v4, as in the commented-out code, but it fails at runtime with a misaligned address error. Does anyone know why? If you do, please tell me in the comments.
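For what it is worth, one plausible cause is alignment: a float4 load compiles to ld.global.v4 only when the address is 16-byte aligned, and isrc + w is 16-byte aligned only when data_offset is a multiple of 4 floats, which is not guaranteed once dim_size is not a multiple of 4 or inner_size > 1. The guard below is only a hypothetical sketch (the helper is_aligned16 is not from the post) of how one might fall back to scalar loads.

// Hypothetical helper, not from the original post: check 16-byte alignment.
__device__ __forceinline__ bool is_aligned16(const void *p) {
    return (reinterpret_cast<size_t>(p) & 0xF) == 0;
}

// Sketch of a guarded load inside the per-row loops:
//   float4 sv;
//   if (is_aligned16(isrc + w)) {
//       sv = *reinterpret_cast<const float4 *>(isrc + w);  // emits ld.global.v4
//   } else {
//       sv = make_float4(isrc[w], isrc[w + 1], isrc[w + 2], isrc[w + 3]);
//   }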