当前位置: 首页 > backend >正文

深度学习|pytorch基本运算-广播失效

【1】引言

前序文章中,已经学习了pytorch基本运算中的生成随机张量、生成多维张量,以及张量的变形、加减和广播运算。

今天的文章在之前学习的基础上,进一步探索。

前序文章链接为:

深度学习|pytorch基本运算-CSDN博客

【2】广播失效

前序文章在最后给出了广播运算的基础代码:

# 导入包
import torch
# 生成多为维张量
y=torch.tensor([1,2,3])
z=torch.tensor([[3],[2],[1]
])
#打印
print('y=',y)
print('z=',z)
# 
a=y+z
print('a=',a)

在这个项目中,行向量y会沿着行广播(复制),列向量z会沿着列广播(复制),实际运行效果和下述代码一样:

# 导入包
import torch
# 生成多为维张量
y=torch.tensor([[1,2,3],[1,2,3],[1,2,3]])
z=torch.tensor([[3,3,3],[2,2,2],[1,1,1]
])
#打印
print('y=',y)
print('z=',z)
#
a=y+z
print('a=',a)

实际运行后的效果为:

图1  广播运行效果 

但实际上,如果稍微修改代码,就会有广播失效的情况:

# 导入包
import torch
# 生成多为维张量
y=torch.tensor([[1,2,3,1],[1,2,3,1],[1,2,3,1]])
z=torch.tensor([[3,3,3],[2,2,2],[1,1,1]
])
#打印
print('y=',y)
print('z=',z)
#
a=y+z
print('a=',a)

上述代码运行后,会直接报错:

    a=y+z
      ~^~
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1

这里报错的意思是:在非单例维度1上,第一个矩阵a有4个数,第二个矩阵b有3个数,无法匹配。

在pytorch中,对于维度的规定是:

在竖直方向是第0维度, 代码中的y和z都有3行,匹配;

在水平方向是第1维度, 代码中的y和z分别有4列和3列,不匹配,无法广播。

需要注意到报错信息中,a(4)和b(3)是python语言报错的惯用写法,实际对应的就是y(4)和z(3)。

真实的不匹配来源是:矩阵y有4列数据,矩阵z有3列数据,矩阵z既不可能每一列都复制一遍来广播,也不可能任选一列复制来广播,所以无法广播;但对于单独的一列,则没有这样的烦恼,直接每一列都复制即可。比如下述代码:

# 导入包
import torch
# 生成多为维张量
y=torch.tensor([[1,2,3,1],[1,2,3,1],[1,2,3,1]])
z=torch.tensor([[3],[2],[1]
])
#打印
print('y=',y)
print('z=',z)
#
a=y+z
print('a=',a)

运行后的效果为:

 图2  单列多行广播运行效果 

 上述情况是第1维度即列的原因造成的无法广播,如果修改第0维度即行来测试,有如下代码:

# 导入包
import torch
# 生成多为维张量
y=torch.tensor([[1,2,3,1],[1,2,3,1],[1,2,3,1]])
z=torch.tensor([[3],[2],[1],[1]
])
#打印
print('y=',y)
print('z=',z)
#
a=y+z
print('a=',a)

代码运行后的报错为:

    a=y+z
      ~^~
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0

和前述分析的原因一样:3行和4行不对应,无法广播。

矩阵y有3行数据,矩阵z有4行数据,矩阵y既不可能每一行都复制一遍来广播,也不可能任选一行复制来广播,所以无法广播。

【3】总结

探索了pytorch的基本运算中广播失效的情况及其原因。 

http://www.xdnf.cn/news/10275.html

相关文章:

  • 方案精读:42页华为企业组织活力设计方案【附全文阅读】
  • 什么是trace,分布式链路追踪(Distributed Tracing)
  • C++ 建造者模式:简单易懂的设计模式解析
  • 保持本地 Git 项目副本与远程仓库完全同步
  • LeetCode 每日一题 2025/5/26-2025/6/1
  • DO指数GPU版本
  • Redis最佳实践——安全与稳定性保障之高可用架构详解
  • 双目相机深度的误差分析(基线长度和相机焦距的选择)
  • python中将一个列表样式的字符串转换成真正列表的办法以及json.dumps()和 json.loads()
  • 谷歌工作自动化——仙盟大衍灵机——仙盟创梦IDE
  • Cypress + React + TypeScript
  • Mybatis:灵活掌控SQL艺术
  • 分享两款使用免费软件,dll修复工具及DirectX修复工具
  • 西瓜书第五章——感知机
  • Qt程序添加调试输出窗口:CONFIG += console
  • Oracle中EXISTS NOT EXISTS的使用
  • 关于用Cloudflare的Zero Trust实现绕过备案访问国内站点说明
  • SEO长尾关键词优化进阶指南
  • springboot集成websocket给前端推送消息
  • Visual Studio笔记:MSVC工具集、MSBuild
  • 【HW系列】—日志介绍
  • Excel快捷键
  • ESP8266常用指令
  • LeetCode Hot100刷题——划分字母区间
  • 第十四篇:MySQL 运维中的故障场景还原与排查实战技巧
  • 华为计试——刷题
  • 计算机网络之路由表更新
  • 第四十一天打卡
  • Unity中的AudioManager
  • 完整解析 Linux Kdump Crash Kernel 工作原理和实操步骤