当前位置: 首页 > java >正文

层在init中只为创建线性层,forward的对线性层中间加非线性运算。且分层定义是为了把原本一长个代码的初始化和运算放到一个组合中。

init注意有几个层,这里有四个层,接下来在forward函数中会把这四个层都用上

model = SimpleViT().to(device),创建模型的时候,会弄出所有层

当outputs = model(images),模型被使用的时候才调用各个定义层时的forward函数,调用顺序如下,注意,每个forward都完全使用了这些层,且以mha多头注意力层为例,其中包含把输入变成qkv大矩阵的性层和其他线性层,线性层之间还有其他的非线性操作,会在forward使用,所以forward不是简单的把线形层组合,其中层与层之间的非线性运算就在这里

为什么要定义这么多类和forward函数呢?就是为了把一长串代码中的初始化放到整个类的初始化中,然后操作放到这个类定义的方法中

举例

对于这个类

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()assert d_model % num_heads == 0self.head_dim = d_model // num_headsself.num_heads = num_headsself.qkv = nn.Linear(d_model, d_model * 3)self.out = nn.Linear(d_model, d_model)def forward(self, x, *_):B, seq_len, d_model = x.shapeqkv = self.qkv(x)qkv = qkv.view(B, seq_len, 3, self.num_heads, self.head_dim)q, k, v = qkv.unbind(dim=2)q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)attn = torch.softmax(scores, dim=-1)context = torch.matmul(attn, v)context = context.transpose(1, 2).reshape(B, seq_len, d_model)return self.out(context)

可以直接使用

mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
y_class = mha(x)

这等效为代码,

# 线性层
qkv_linear = nn.Linear(d_model, d_model*3)
out_linear = nn.Linear(d_model, d_model)# ----------------------------
# Multi-Head Attention 流程
# ----------------------------# 1. 线性映射得到 QKV
qkv = qkv_linear(x)  # (B, seq_len, 3*d_model)# 2. reshape 成 (B, seq_len, 3, num_heads, head_dim)
qkv = qkv.view(B, seq_len, 3, num_heads, head_dim)# 3. 拆分 q, k, v
q, k, v = qkv.unbind(dim=2)  # 每个 (B, seq_len, num_heads, head_dim)# 4. 转置到 (B, num_heads, seq_len, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)# 5. 注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)  # (B, num_heads, seq_len, seq_len)# 6. softmax 得到注意力权重
attn = torch.softmax(scores, dim=-1)# 7. 加权求和得到上下文
context = torch.matmul(attn, v)  # (B, num_heads, seq_len, head_dim)# 8. 转置回 (B, seq_len, num_heads, head_dim)
context = context.transpose(1, 2)# 9. 拼回 d_model
context = context.reshape(B, seq_len, d_model)# 10. 输出线性映射
out = out_linear(context)  # (B, seq_len, d_model)

http://www.xdnf.cn/news/18301.html

相关文章:

  • 机械革命电竞控制台一直加载无法点击故障
  • MySQL事务及原理详解
  • 牛津大学xDeepMind 自然语言处理(3)
  • 工业电脑选得好生产效率节节高稳定可靠之选
  • C/C++ 与嵌入式岗位常见笔试题详解
  • Mac电脑上虚拟机共享文件夹权限问题
  • vscode连接docker
  • WIFI国家码修改信道方法_高通平台
  • 精品方案 | GCKontrol与OMNeT++联合仿真在机载网络性能分析中的应用
  • mvdr波束形成
  • Linux系统之部署nullboard任务管理工具
  • ios八股文 -- Objective-c
  • iOS 应用上架常见问题与解决方案,多工具组合的实战经验
  • Node.js中的Prisma应用:现代数据库开发的最佳实践
  • 单片机通信协议核心关系梳理笔记(UART/USART/232/485/SPI/12C/LIN/BLE/WIFI)
  • leetcode7二分查找_69 and 34
  • 链表的核心:“增删改查”
  • Nginx 负载均衡和缓存配置
  • 【软考架构】净室软件工程
  • Gin自定义Error中间件
  • SQL-leetcode— 2356. 每位教师所教授的科目种类的数量
  • 手机 浏览器调用摄像头扫描二维码Quagga
  • 2026 济南淀粉深加工展览会亮点:玉米科技与未来产业发展
  • 03-dockerfile
  • C++继承中的虚函数机制:从单继承到多继承的深度解析
  • 【ansible】2.实施ansible playbook
  • 机器学习算法核心总结
  • 静/动态库 IIC(arm) day58
  • 医疗问答应用:UniApp + Node.js + DeepSeek API
  • Objective-C 版本的 LiveEventBus 效果