当前位置: 首页 > ds >正文

PyTorch_张量拼接

张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如:残差网络,注意力机制中都使用张量拼接。


torch.cat 函数的使用

可以将两个张量根据指定的维度拼接起来。

import torch 
import numpy as np def test01():data1 = torch.randint(0, 10, [3, 4, 5])data2 = torch.randint(0, 10, [3, 4, 5])print(data1.shape)print(data2.shape)# dim 对应的值可以是负数,可以通过list来思考# 按照第 0 维度进行拼接new_data = torch.cat([data1, data2], dim = 0)  # 是列表print(new_data.shape)# 按照第 1 维度进行拼接new_data = torch.cat([data1, data2], dim = 1)print(new_data.shape)# 按照第 2 维度进行拼接new_data = torch.cat([data1, data2], dim = 2)print(new_data.shape)if __name__ == "__main__":test01() 

torch.stack 函数的使用

torch.stack 函数可以将两个张量根据指定的维度叠加起来,或者组合成新的元素。叠加的意思:当两个元素叠在一起,我们就将这两个元素当作一个元素。

import torch 
import numpy as np def test01():data1 = torch.randint(0, 10, [2, 3])data2 = torch.randint(0, 10, [2, 3])print(data1)print(data2)# 将两个张量 stack 叠加起来,像 cat 一样指定维度# 1. 按照第0维度进行叠加new_data = torch.stack([data1, data2], dim=0)print(new_data.shape)# 2. 按照第1维度进行叠加new_data = torch.stack([data1, data2], dim=1)print(new_data)# 3. 按照第2维度进行叠加new_data = torch.stack([data1, data2], dim=2)print(new_data)if __name__ == "__main__":test01() 
http://www.xdnf.cn/news/4024.html

相关文章:

  • ES6入门---第三单元 模块四:Set和WeakSet
  • SQL手工注入(DVWA)
  • 「Mac畅玩AIGC与多模态17」开发篇13 - 条件判断与分支跳转工作流示例
  • 交互式智能体面临长周期决策和随机环境反馈交互等挑战 以及解决办法
  • 记录一次手动更新英特尔Management Engine固件的经历
  • Python绘制误差棒:深入解析数据的不确定性
  • 文章记单词 | 第62篇(六级)
  • W-TinyLFU缓存驱逐算法解析
  • Maven框架详解:构建与依赖管理的利器
  • 《奇迹世界起源》:宝箱工坊介绍!
  • MyBatis 核心类详解与架构解析:从入门到源码级理解
  • 《前端秘籍:SCSS阴影效果全兼容指南》
  • Linux的系统周期化任务
  • ES类的索引轮换
  • JVM——JVM是怎么实现invokedynamic的?
  • HttpPrinter 是一款功能强大的跨平台 Web 打印解决方案
  • C与指针——结构与联合
  • Feign的原理
  • cesium基础设置
  • xx外卖知识补充
  • 日语学习-日语知识点小记-进阶-JLPT-N1阶段(1):语法单词
  • Jetpack Compose 边距终极指南:Margin 和 Padding 的正确处理方式
  • 详细案例,集成算法
  • 高等数学第三章---微分中值定理与导数的应用(3.3泰勒(Taylor)公式)
  • JAVA组织/岗位拉取多段时间属性到一张表上时,时间段分隔问题
  • 解释一下NGINX的反向代理和正向代理的区别?
  • 【C++重载操作符与转换】下标操作符
  • Android学习总结之事件分发机制篇
  • Java大厂面试:Java技术栈中的核心知识点
  • 25.5.4数据结构|哈夫曼树 学习笔记