当前位置: 首页 > ds >正文

门控线性单元GLU (Gated Linear Unit)

文章目录

    • 门控线性单元GLU (Gated Linear Unit)
      • 函数表达式
      • 与 Swish 的对比
      • PyTorch 中的 GLU 实现
      • TensorFlow 中的 GLU 实现

门控线性单元GLU (Gated Linear Unit)

  • 论文

    https://arxiv.org/abs/1612.08083

  • 门控线性单元(GLU)最初在《Language Modeling with Gated Convolutional Networks》提出,设计灵感来自门控控制,通过引入门控操作来控制信息的流动,它巧妙地将线性变换门控机制结合起来,通过可学习的门控信号来控制信息流,可以看做是引入了一种动态的选择极值,以在模型中选择性地传递信息

  • GLU 的有效性来源于其直观的工作流程:

    1. 双线性变换: 首先,输入 x 会并行地进行两次独立的线性变换。一次生成“候选”输出 (xW+b),这是潜在需要传递的信息。
    2. 门控过滤: 同时,另一个线性变换 (xV+c) 的结果会通过 Sigmoid 函数生成一个介于 0 和 1 之间的门控信号。这个门像一个智能开关,决定了“候选”输出中每个维度的信息应该有多少被保留,有多少应该被抑制。
    3. 残差友好: 由于门控输出的平均值大约为 0.5,这使得 GLU 具有一种天然的“残差”特性,有助于缓解深度网络训练中的梯度消失问题

函数表达式

  • GLU函数
    GLU(x)=(xW+b)⊗σ(xV+c)\begin{aligned} \mathrm{GLU(x)}=(xW+b)\otimesσ(xV+c) \end{aligned} GLU(x)=(xW+b)σ(xV+c)

    其中

    • x∈Rdx \in \mathbb{R}^dxRd 为输入向量
    • W、V∈Rd×dW、V \in \mathbb{R}^{d \times d}WVRd×db、c∈Rdb、c \in \mathbb{R}^dbcRd 为可学习的权重矩阵与偏置向量
    • ⊗\otimes 表示逐元素乘积(哈达玛乘积)
    • σ(⋅)\sigma(\cdot)σ() 为 sigmoid 门控,将后面 xV+cxV+cxV+c 值压缩到 (0, 1) 区间,作为门控信号,决定信息通过比例

与 Swish 的对比

  • 与swish对比

    特性SwishGLU
    参数标量 β\betaβ(固定或可学习)全连接权重 WWW、偏置 bbb(可学习)
    门控方式输入自身经过 sigmoid 缩放输入经线性变换后再经 sigmoid 门控
    参数量每通道 0/1 个标量每通道 d+1d+1d+1 个参数
    计算复杂度低(一次 sigmoid)高(一次矩阵乘 + sigmoid)
    表达能力中等

PyTorch 中的 GLU 实现

  • 代码(以 nn.GLU 为例,针对通道维度切分)

    注意:使用官方的GLU函数,输出维度是减半的

    import torch
    import torch.nn as nntorch.manual_seed(1024)batch_size = 8
    seq_len = 64
    d_model = 512x = torch.randn(batch_size, seq_len, d_model)# 官方 GLU 沿指定维度将输入一分为二
    glu = nn.GLU(dim=-1)          # dim 指定切分维度
    out = glu(x)                  # 输出 [batch_size, seq_len//2, d_model]print("Input shape :", x.shape)
    print("Output shape:", out.shape)"""输出"""
    Input shape : torch.Size([8, 64, 512])
    Output shape: torch.Size([8, 64, 256])
    

    若输入通道维度(seq_len)为偶数,可直接使用 nn.GLU(dim=channel_dim),此时将输入均分两份:前一半做值、后一半做门控。

  • 自定义 GLU(任意线性映射 + 门控)

    注意:输出维度可以

    import torch
    import torch.nn as nn
    torch.manual_seed(1024)class GLU(nn.Module):def __init__(self, d_in, d_out):super().__init__()self.w1 = nn.Linear(d_in, d_out, bias=False)self.w2 = nn.Linear(d_in, d_out, bias=False)self.w3 = nn.Linear(d_out, d_in, bias=False)  # 可选:再投影回 d_indef forward(self, x):# x: [batch_size, seq_len, d_model]gate = torch.sigmoid(self.w2(x))   # [batch_size, seq_len, d_in]out  = self.w1(x) * gate           # [batch_size, seq_len, d_out]return self.w3(out)                # [batch_size, seq_len, d_in]# 使用示例
    batch_size = 8
    seq_len = 64
    d_model = 512
    d_ff = 4 * d_modelx = torch.randn(batch_size, seq_len, d_model)
    layer = GLU(d_in=d_model, d_out=d_ff)
    print(layer(x).shape)   # 维度不变"""输出"""
    torch.Size([8, 64, 512])
    

TensorFlow 中的 GLU 实现

  • 代码(tf.keras 自定义层)

    import tensorflow as tfclass GLU(tf.keras.layers.Layer):"""典型 Transformer-FFN 中的 GLU 层:GLU(x) = (x W_gate) ⊙ σ(x W_up)  再投影回 d_in,维度不变"""def __init__(self, d_in, d_out, **kwargs):super().__init__(**kwargs)self.d_in = d_inself.d_out = d_out# 两路线性映射self.w_gate = tf.keras.layers.Dense(d_out, use_bias=False)self.w_up   = tf.keras.layers.Dense(d_out, use_bias=False)self.w_down = tf.keras.layers.Dense(d_in, use_bias=False)def call(self, x):gate = tf.nn.sigmoid(self.w_gate(x))   # [batch_size, seq_len, d_out]up   = self.w_up(x)                    # [batch_size, seq_len, d_out]return self.w_down(gate * up)          # [batch_size, seq_len, d_in]# 使用示例
    batch_size = 8
    seq_len = 64
    d_model = 512
    d_ff = 4 * d_modelx = tf.random.normal([batch_size, seq_len, d_model])
    glu = GLU(d_in=d_model, d_out=d_ff)
    print(glu(x).shape)   # (4, 64, 512)  维度不变"""输出"""
    (8, 64, 512)
    

http://www.xdnf.cn/news/15854.html

相关文章:

  • Go语言流程控制(if / for)
  • 一小时学习Redis
  • websocket案例 599足球比分
  • 海森矩阵(Hessian Matrix)在SLAM图优化和点云配准中的应用介绍
  • 实战指南|智慧无人机安防系统搭建全流程解析
  • 深入理解Linux文件操作:stdin/stdout/stderr与C语言文件函数全解析
  • PDF 拆分合并PDFSam:开源免费 多文件合并 + 按页码拆分 本地处理
  • 突破性量子芯片问世:电子与光子首次集成,开启量子技术规模化应用新篇章
  • 暑期自学嵌入式——Day05补充(C语言阶段)
  • 二分答案之第 K 小/大
  • Visual Studio编译WPF项目生成的文件介绍
  • 服务器mysql数据的简单备份脚本
  • 二、Dify 版本升级教程(LInux-openeuler)
  • iOS OC 图片压缩
  • vue2 面试题及详细答案150道(101 - 120)
  • 国产替代:ASP4644在电信通信设备中的测试与应用前景
  • Java类:BigDecimal 的用法:乘法
  • IDEA 2020.1版本起下载JDK
  • Logback 配置的利器:深入理解<property>与<variable>
  • vue2 面试题及详细答案150道(21 - 40)
  • 闭包的定义和应用场景
  • Rust实战:高效对接Postman API
  • Spring中的SpEL是什么
  • Springboot3整合Elasticsearch8(elasticsearch-java)
  • [2025CVPR-目标检测方向]FSHNet:一种用于3D物体检测的全稀疏混合网络。
  • Hive数据仓库工具
  • 什么是高光谱相机,它与数码相机有什么区别?
  • 相机光学(五十)——Depth AF
  • RTKLIB读取星历文件,观测数据
  • 解决Flutter运行android提示Deprecated imperative apply of Flutter‘s Gradle plugins