当前位置: 首页 > news >正文

【深度学习基础】PyTorch Tensor生成方式及复制方法详解

目录

  • PyTorch Tensor生成方式及复制方法详解
    • 一、Tensor的生成方式
      • (一)从Python列表/元组创建
      • (二)从NumPy数组创建
      • (三)特殊初始化方法
      • (四)从现有Tensor创建
    • (五)高级初始化方法
    • 二、复制方法对比
      • (一) `torch.tensor()` vs `torch.from_numpy()`
      • (二) `.clone()` vs `.copy_()` vs `copy.deepcopy()`
      • (三) 深度拷贝(Deep Copy)
    • 三、核心区别总结
    • 四、最佳实践建议

PyTorch Tensor生成方式及复制方法详解

在PyTorch中,Tensor的创建和复制是深度学习开发的基础操作。本文将全面总结Tensor的各种生成方式,并深入分析不同复制方法的区别。


一、Tensor的生成方式

(一)从Python列表/元组创建

import torch# 直接创建Tensor
t1 = torch.tensor([1, 2, 3])          # 整型Tensor
t2 = torch.tensor([[1.0, 2], [3, 4]])  # 浮点型Tensor

(二)从NumPy数组创建

import numpy as nparr = np.array([1, 2, 3])
t = torch.from_numpy(arr)  # 共享内存

(三)特殊初始化方法

zeros = torch.zeros(2, 3)      # 全0矩阵
ones = torch.ones(2, 3)       # 全1矩阵
rand = torch.rand(2, 3)        # [0,1)均匀分布
randn = torch.randn(2, 3)      # 标准正态分布
arange = torch.arange(0, 10, 2) # 0-10步长为2

(四)从现有Tensor创建

x = torch.tensor([1, 2, 3])
x1 = x.new_tensor([4, 5, 6])  # 新Tensor(复制数据)
x2 = torch.zeros_like(x)       # 形状相同,全0
x3 = torch.randn_like(x)       # 形状相同,随机值

(五)高级初始化方法

eye = torch.eye(3)             # 3x3单位矩阵
lin = torch.linspace(0, 1, 5)  # 0-1等分5份
log = torch.logspace(0, 2, 3)  # 10^0到10^2等分3份

二、复制方法对比

(一) torch.tensor() vs torch.from_numpy()

方法数据源内存共享梯度传递数据类型
torch.tensor()Python数据不共享支持自动推断
torch.from_numpy()NumPy数组共享不支持保持一致
# 示例:内存共享验证
arr = np.array([1, 2, 3])
t = torch.from_numpy(arr)
arr[0] = 99  # 修改NumPy数组
print(t)      # tensor([99, 2, 3]),同步变化

(二) .clone() vs .copy_() vs copy.deepcopy()

方法内存共享梯度传递计算图保留使用场景
.clone()不共享保留梯度保留计算图需要梯度回传
.copy_()目标共享不保留破坏计算图高效覆盖数据
copy.deepcopy()不共享不保留不保留完全独立拷贝
# 示例:梯度传递对比
x = torch.tensor([1.], requires_grad=True)
y = x.clone()
z = torch.tensor([2.], requires_grad=True)
z.copy_(x)  # 覆盖z的值y.backward()  # 正常回传梯度到x
# z.backward() # 报错!copy_()破坏计算图

(三) 深度拷贝(Deep Copy)

import copyorig = torch.tensor([1, 2, 3])
deep_copied = copy.deepcopy(orig)  # 完全独立拷贝

三、核心区别总结

  1. 内存共享

    • from_numpy() 与NumPy共享内存
    • 视图操作(如view()/切片)共享内存
    • 其他方法均创建独立副本
  2. 梯度处理

    • .clone() 唯一保留梯度计算图
    • copy_() 会破坏目标Tensor的计算图
    • torch.tensor() 创建新计算图
  3. 使用场景

    • 需要梯度回传:使用.clone()
    • 高效数据覆盖:使用.copy_()
    • 完全独立拷贝:使用copy.deepcopy()
    • 与NumPy交互:使用from_numpy()/numpy()

四、最佳实践建议

  1. 优先使用torch.tensor()创建新Tensor
  2. 需要从NumPy导入数据且避免复制时用from_numpy()
  3. 在计算图中复制数据时必须使用.clone()
  4. 需要覆盖现有Tensor数据时使用.copy_()
  5. 调试时注意内存共享可能导致的意外修改
# 正确梯度传递示例
x = torch.tensor([1.], requires_grad=True)
y = x.clone() ** 2  # 保留计算图
y.backward()        # 梯度可回传到x
http://www.xdnf.cn/news/1321777.html

相关文章:

  • <数据集>遥感飞机识别数据集<目标检测>
  • 基于深度学习的车牌检测识别系统:YOLOv5实现高精度车牌定位与识别
  • Android Coil3视频封面抽取封面帧存Disk缓存,Kotlin(2)
  • 【LLM1】大型语言模型的基本生成机制
  • 华清远见25072班C语言学习day11
  • 当使用STL容器去存放数据时,是存放对象合适,还是存放对象指针(对象地址)合适?
  • 【C++】 using声明 与 using指示
  • Linux内存管理系统性总结
  • Orange的运维学习日记--45.Ansible进阶之文件部署
  • 获粤港澳大湾区碳足迹认证:遨游智能三防手机赋能绿色通信
  • LeetCode:无重复字符的最长子串
  • 实践笔记-VSCode与IDE同步问题解决指南;程序总是进入中断服务程序。
  • LAMP 架构部署:Linux+Apache+MariaDB+PHP
  • 规避(EDR)安全检测--避免二进制文件落地
  • 云计算- KubeVirt 实操指南:VM 创建 、存储挂载、快照、VMI全流程 | 容器到虚拟机(镜像转换/资源调度)
  • 前端处理导出PDF。Vue导出pdf
  • 王树森深度强化学习DRL(三)围棋AlphaGo+蒙特卡洛
  • STRIDE威胁模型
  • 新手向:Java方向讲解
  • Python实战--基于Django的企业资源管理系统
  • 块体不锈钢上的光栅耦合表面等离子体共振的复现
  • 后端通用基础代码
  • 在嵌入式单片机开发中,通过校验和或者校验码来比对程序版本好有何优劣势
  • 【OLAP】trino安装和基本使用
  • 【完整源码+数据集+部署教程】无人机目标检测系统源码和数据集:改进yolo11-efficientViT
  • Linux网络服务(一)——计算机网络参考模型与子网划分
  • Linux bash核心介绍及目录命令
  • Android中使用RxJava实现网络请求与缓存策略
  • Git-2025-0818
  • 数据结构:查找表