
layernorm backward CUDA optimization analysis

Overview

This article is aimed at readers who already have a CUDA background and need a quick implementation of layernorm backward. If you want a detailed treatment of the math behind layernorm backward or of the optimization details, please see the articles in the reference links; this post focuses on the code itself. If you have corrections or suggestions, please feel free to share them. Thank you!

Many excellent write-ups already explain the theory and optimization methods of layernorm_bwd in detail (see the reference links), so that material is not repeated here; this post simply reproduces the common layernorm_bwd optimizations in code.

1. layernorm_bwd algorithm and CPU implementation

  • layernorm_bwd formula derivation: the derivation figures from the original post are not reproduced here; the key results are summarized below, and reference link 2 gives the full derivation.
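The following is a reconstruction of those results, written to match the code below rather than the original figures. For one row of length H, with normalized values \hat{x}_i = (x_i - \mu)\, r, where \mu is the row mean and r = 1/\sqrt{\sigma^2 + \varepsilon} is rstd, and upstream gradient dy:

    d\beta_i \mathrel{+}= dy_i, \qquad d\gamma_i \mathrel{+}= \hat{x}_i \, dy_i

and with g_i = \gamma_i \, dy_i, \quad c_1 = \frac{1}{H}\sum_{j=1}^{H} g_j \ (\text{dnorm\_mean}), \quad c_2 = \frac{1}{H}\sum_{j=1}^{H} g_j \hat{x}_j \ (\text{dnorm\_norm\_mean}),

    dx_i = r \,\bigl(g_i - c_1 - \hat{x}_i \, c_2\bigr).

The d\beta and d\gamma accumulations run over every (batch, seq) row, which is why the GPU kernels below need atomicAdd for dbias and dweight.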

template<typename T, typename T_ACC>
void layernorm_backward_cpu(T* dinput, T* dweight, T* dbias, T* doutput, T* input, T* weight,
                            T_ACC* mean, T_ACC* rstd,
                            const int batch, const int seq_len, const int hidden_dim)
{
    for (int b = 0; b < batch; b++) {
        for (int i = 0; i < seq_len; i++) {
            const T* doutput_offset = doutput + b * seq_len * hidden_dim + i * hidden_dim;
            T* dinput_offset = dinput + b * seq_len * hidden_dim + i * hidden_dim;
            const T* input_offset = input + b * seq_len * hidden_dim + i * hidden_dim;
            const T_ACC mean_val = mean[b * seq_len + i];
            const T_ACC rstd_val = rstd[b * seq_len + i];
            T dnorm_mean = 0.0f;
            T dnorm_norm_mean = 0.0f;
            // first pass: reduce dnorm_mean and dnorm_norm_mean over the row
            for (int j = 0; j < hidden_dim; j++) {
                T norm_bti = (input_offset[j] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
                T dnorm_i = weight[j] * doutput_offset[j];
                dnorm_mean += dnorm_i;
                dnorm_norm_mean += dnorm_i * norm_bti;
            }
            dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
            dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
            // second pass: accumulate the three gradients
            for (int j = 0; j < hidden_dim; j++) {
                T norm_bti = (input_offset[j] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
                T dnorm_i = weight[j] * doutput_offset[j];
                // gradient to bias
                dbias[j] += doutput_offset[j];
                // gradient to weight
                dweight[j] += norm_bti * doutput_offset[j];
                // gradient to input
                T dval = 0.0f;
                dval += dnorm_i;
                dval -= dnorm_mean;
                dval -= norm_bti * dnorm_norm_mean;
                dval *= rstd_val;
                dinput_offset[j] += dval;
            }
        }
    }
}
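For completeness, here is a hypothetical host-side harness (not part of the original post) showing one way to validate a GPU kernel against the CPU reference above. It assumes the CPU function and layernorm_backward_kernel1 from section 2.1 live in the same .cu file; buffer names mirror the launch snippets below, and only dbias is compared (the other outputs can be checked the same way).

#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    const int batch = 16, seq_len = 64, hidden_dim = 2048;
    const int rows = batch * seq_len, n = rows * hidden_dim;
    const float eps = 1e-5f;

    // random inputs plus per-row mean / rstd computed on the host
    std::vector<float> input(n), weight(hidden_dim, 1.0f), doutput(n), mean(rows), rstd(rows);
    for (int i = 0; i < n; i++) {
        input[i] = (rand() % 1000) / 500.0f - 1.0f;
        doutput[i] = (rand() % 1000) / 500.0f - 1.0f;
    }
    for (int r = 0; r < rows; r++) {
        float m = 0.f, v = 0.f;
        for (int j = 0; j < hidden_dim; j++) m += input[r * hidden_dim + j];
        m /= hidden_dim;
        for (int j = 0; j < hidden_dim; j++) { float d = input[r * hidden_dim + j] - m; v += d * d; }
        mean[r] = m;
        rstd[r] = 1.0f / sqrtf(v / hidden_dim + eps);
    }

    // CPU reference
    std::vector<float> dinput_cpu(n, 0.f), dweight_cpu(hidden_dim, 0.f), dbias_cpu(hidden_dim, 0.f);
    layernorm_backward_cpu<float, float>(dinput_cpu.data(), dweight_cpu.data(), dbias_cpu.data(),
                                         doutput.data(), input.data(), weight.data(),
                                         mean.data(), rstd.data(), batch, seq_len, hidden_dim);

    // GPU buffers
    float *input_gpu, *weight_gpu, *doutput_gpu, *mean_gpu, *rstd_gpu;
    float *dinput_gpu, *dweight_gpu, *dbias_gpu;
    cudaMalloc(&input_gpu, n * sizeof(float));
    cudaMalloc(&doutput_gpu, n * sizeof(float));
    cudaMalloc(&weight_gpu, hidden_dim * sizeof(float));
    cudaMalloc(&mean_gpu, rows * sizeof(float));
    cudaMalloc(&rstd_gpu, rows * sizeof(float));
    cudaMalloc(&dinput_gpu, n * sizeof(float));
    cudaMalloc(&dweight_gpu, hidden_dim * sizeof(float));
    cudaMalloc(&dbias_gpu, hidden_dim * sizeof(float));
    cudaMemcpy(input_gpu, input.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(doutput_gpu, doutput.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(weight_gpu, weight.data(), hidden_dim * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(mean_gpu, mean.data(), rows * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(rstd_gpu, rstd.data(), rows * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(dinput_gpu, 0, n * sizeof(float));
    cudaMemset(dweight_gpu, 0, hidden_dim * sizeof(float));
    cudaMemset(dbias_gpu, 0, hidden_dim * sizeof(float));

    // launch the v1 kernel and compare dbias against the CPU result
    dim3 block(256, 1);
    dim3 grid((rows + block.x - 1) / block.x, 1);
    layernorm_backward_kernel1<float, float><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu,
        doutput_gpu, input_gpu, weight_gpu, mean_gpu, rstd_gpu, batch, seq_len, hidden_dim);
    cudaDeviceSynchronize();

    std::vector<float> dbias(hidden_dim);
    cudaMemcpy(dbias.data(), dbias_gpu, hidden_dim * sizeof(float), cudaMemcpyDeviceToHost);
    float max_err = 0.f;
    for (int j = 0; j < hidden_dim; j++) max_err = fmaxf(max_err, fabsf(dbias[j] - dbias_cpu[j]));
    printf("dbias max abs err = %f\n", max_err);
    return 0;
}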

2. layernorm_bwd CUDA optimization methods and implementations

2.1 layernorm_bwd_v1

  • Optimization: in v1 each thread handles one row, so there are batch*seq_len threads in total, and each thread loops over the hidden_dim elements of its row.
template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel1(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < batch * seq_len) {
        const T* doutput_offset = doutput + idx * hidden_dim;
        T* dinput_offset = dinput + idx * hidden_dim;
        const T* input_offset = input + idx * hidden_dim;
        const T_ACC mean_val = mean[idx];
        const T_ACC rstd_val = rstd[idx];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = 0; i < hidden_dim; i++) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = 0; i < hidden_dim; i++) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            // gradient to bias
            atomicAdd(&(dbias[i]), doutput_offset[i]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * doutput_offset[i]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(256, 1);
    dim3 grid((batch * seq_len + block.x - 1) / block.x, 1);   // round up so every row is covered
    util::print_cuda_cfg(grid, block);
    layernorm_backward_kernel1<T, T_ACC><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                          input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                          batch, seq_len, hidden_dim);

2.2 layernorm_bwd_v2

  • Optimization: in v2 each warp handles one row, so there are batch*seq_len warps in total, and each warp loops over the hidden_dim elements of its row; the two per-row reduction sums (dnorm_mean and dnorm_norm_mean) are computed inside the warp with warp shuffle instructions.
template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel2(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = warpReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = warpReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            // gradient to bias
            atomicAdd(&(dbias[i]), doutput_offset[i]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * doutput_offset[i]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(32, 1);
    dim3 grid(1, batch * seq_len);
    layernorm_backward_kernel2<T, T_ACC><<<grid, block>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                          input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                          batch, seq_len, hidden_dim);

2.3 layernorm_bwd_v3

  • Optimization: like v2, 32 threads still process one row, but this version stages doutput in shared memory (smem) to avoid repeated accesses to global memory.
template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel3(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T *smem = reinterpret_cast<T *>(tmp_smem);
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = warpReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = warpReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            smem[tx] = doutput_offset[i];
            __syncthreads();
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];
            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(32, 1);
    dim3 grid(1, batch * seq_len);
    size_t smem_size = sizeof(T) * block.x;
    layernorm_backward_kernel3<T, T_ACC><<<grid, block, smem_size>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                                     input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                                     batch, seq_len, hidden_dim);

2.4 layernorm_bwd_v4

  • Optimization: building on v3, v4 uses 1024 threads to loop over one row.
#define WARP_SIZE 32   // used below but not defined in the original snippet

template <typename T>
__device__ T warpReduceSum(T val) {
#pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

template<typename T>
__device__ __inline__ T blockReduceSum(T val) {
    // assumes blockDim.x == 1024 so that all WARP_SIZE entries of shared[] are written before being read
    __shared__ T shared[WARP_SIZE];
    __shared__ T ret;
    int warp_id = threadIdx.x / WARP_SIZE;
    int lane_id = threadIdx.x % WARP_SIZE;
    val = warpReduceSum(val);
    if (lane_id == 0) {
        shared[warp_id] = val;
    }
    __syncthreads();
    val = (threadIdx.x < WARP_SIZE) ? shared[threadIdx.x] : (T)(0.0f);
    val = warpReduceSum(val);
    if (threadIdx.x == 0) {
        ret = val;
    }
    __syncthreads();
    return ret;
}

template<typename T, typename T_ACC>
__global__ void layernorm_backward_kernel4(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    int tx = threadIdx.x;
    int by = blockIdx.y;
    extern __shared__ unsigned char tmp_smem[];
    T *smem = reinterpret_cast<T *>(tmp_smem);
    if (by < batch * seq_len) {
        const T* doutput_offset = doutput + by * hidden_dim;
        T* dinput_offset = dinput + by * hidden_dim;
        const T* input_offset = input + by * hidden_dim;
        const T_ACC mean_val = mean[by];
        const T_ACC rstd_val = rstd[by];
        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = blockReduceSum<T>(dnorm_mean);
        dnorm_norm_mean = blockReduceSum<T>(dnorm_norm_mean);
        dnorm_mean = dnorm_mean / static_cast<T>(hidden_dim);
        dnorm_norm_mean = dnorm_norm_mean / static_cast<T>(hidden_dim);
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            smem[tx] = doutput_offset[i];
            __syncthreads();
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * smem[tx];
            // gradient to bias
            atomicAdd(&(dbias[i]), smem[tx]);
            // gradient to weight
            atomicAdd(&(dweight[i]), norm_bti * smem[tx]);
            // gradient to input
            T dval = 0.0f;
            dval += dnorm_i;
            dval -= dnorm_mean;
            dval -= norm_bti * dnorm_norm_mean;
            dval *= rstd_val;
            dinput_offset[i] += dval;
        }
    }
}
    dim3 block(1024, 1);
    dim3 grid(1, batch * seq_len);
    size_t smem_size = sizeof(T) * block.x;
    util::print_cuda_cfg(grid, block);
    layernorm_backward_kernel4<T, T_ACC><<<grid, block, smem_size>>>(dinput_gpu, dweight_gpu, dbias_gpu, doutput_gpu,
                                                                     input_gpu, weight_gpu, mean_gpu, rstd_gpu,
                                                                     batch, seq_len, hidden_dim);

2.5 Other layernorm_bwd optimization approaches

The performance bottleneck of v4 is the atomicAdd updates to dbias and dweight: for every element of dbias and dweight, batch * seq_len threads accumulate into the same memory location serially, which is expensive. One remedy is to let each block of (1024, 1) threads process several rows: first accumulate each row's smem[tx] and norm_bti * smem[tx] contributions into registers, and only afterwards issue atomicAdd with the register-accumulated values. This reduces the number of threads that have to perform atomicAdd and the amount of serialized work, improving performance. A sketch of this idea is given below.
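The following is a minimal sketch of that idea, not code from the original post: the kernel name layernorm_backward_kernel5, the ROWS_PER_BLOCK template parameter, and MAX_COLS_PER_THREAD are illustrative; it reuses blockReduceSum from v4; for brevity it reads doutput directly instead of staging it in shared memory; and it assumes hidden_dim <= MAX_COLS_PER_THREAD * blockDim.x so each thread can keep its per-column partial sums in registers.

// Hypothetical kernel: each block processes ROWS_PER_BLOCK consecutive rows.
// Per-column dbias/dweight partial sums are kept in registers across rows and
// flushed with one atomicAdd per column per block, instead of one atomicAdd
// per column per row as in v1-v4.
template<typename T, typename T_ACC, int ROWS_PER_BLOCK>
__global__ void layernorm_backward_kernel5(T* dinput, T* dweight, T* dbias, const T* doutput,
                                           T* input, T* weight, T_ACC* mean, T_ACC* rstd,
                                           const int batch, const int seq_len, const int hidden_dim)
{
    const int tx = threadIdx.x;
    const int row_begin = blockIdx.y * ROWS_PER_BLOCK;
    const int row_end = min(row_begin + ROWS_PER_BLOCK, batch * seq_len);

    // Each thread owns the fixed column set {tx, tx + blockDim.x, ...};
    // assumes hidden_dim <= MAX_COLS_PER_THREAD * blockDim.x.
    constexpr int MAX_COLS_PER_THREAD = 4;
    T dbias_reg[MAX_COLS_PER_THREAD] = {0};
    T dweight_reg[MAX_COLS_PER_THREAD] = {0};

    for (int row = row_begin; row < row_end; row++) {
        const T* doutput_offset = doutput + row * hidden_dim;
        const T* input_offset = input + row * hidden_dim;
        T* dinput_offset = dinput + row * hidden_dim;
        const T_ACC mean_val = mean[row];
        const T_ACC rstd_val = rstd[row];

        T dnorm_mean = 0.0f;
        T dnorm_norm_mean = 0.0f;
        for (int i = tx; i < hidden_dim; i += blockDim.x) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            dnorm_mean += dnorm_i;
            dnorm_norm_mean += dnorm_i * norm_bti;
        }
        dnorm_mean = blockReduceSum<T>(dnorm_mean) / static_cast<T>(hidden_dim);
        dnorm_norm_mean = blockReduceSum<T>(dnorm_norm_mean) / static_cast<T>(hidden_dim);

        for (int i = tx, c = 0; i < hidden_dim; i += blockDim.x, c++) {
            T norm_bti = (input_offset[i] - static_cast<T>(mean_val)) * static_cast<T>(rstd_val);
            T dnorm_i = weight[i] * doutput_offset[i];
            // accumulate dbias / dweight contributions locally instead of atomicAdd per row
            dbias_reg[c] += doutput_offset[i];
            dweight_reg[c] += norm_bti * doutput_offset[i];
            // gradient to input (same as v4)
            T dval = dnorm_i - dnorm_mean - norm_bti * dnorm_norm_mean;
            dinput_offset[i] += dval * rstd_val;
        }
    }

    // one atomicAdd per column per block
    for (int i = tx, c = 0; i < hidden_dim; i += blockDim.x, c++) {
        atomicAdd(&(dbias[i]), dbias_reg[c]);
        atomicAdd(&(dweight[i]), dweight_reg[c]);
    }
}

A matching launch would use dim3 block(1024, 1) and dim3 grid(1, (batch * seq_len + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK). The larger ROWS_PER_BLOCK is, the fewer atomic updates hit each dweight/dbias location, at the cost of having fewer blocks to fill the GPU.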

3. Performance comparison of the layernorm_bwd versions

Data type and problem size: FP32, 16 x 64 x 2048 (batch x seq_len x hidden_dim)
Hardware platform: A100-SXM

layernorm_bwd version    cycles
layernorm_bwd_v1         7482424
layernorm_bwd_v2         251740
layernorm_bwd_v3         253976
layernorm_bwd_v4         98369

Reference links

1. https://zhuanlan.zhihu.com/p/694974164 (layernorm CUDA code implementation)
2. https://www.jianshu.com/p/db89d62e1974 (layernorm backward derivation)
