当前位置: 首页 > news >正文

Pytorch基础操作

在这里插入图片描述

面试的时候,PhD看我简历上面写了”熟悉pytorch框架“,然后就猛猛提问了有关于tensor切片的问题…当然是没答上来,因此在这里整理一下pytorch的一些基础编程语法,常看常新


PyTorch基础操作全解

一、张量初始化

PyTorch的核心数据结构是torch.Tensor,初始化方法灵活多样:

1. 基础初始化


import torch# 未初始化张量(内存中可能存在随机值)a = torch.empty(3, 2)  # 3x2的未初始化矩阵# 均匀分布随机数 [0,1)b = torch.rand(2, 3)   # 2x3随机矩阵# 全零矩阵(显式指定类型)c = torch.zeros(4, 3, dtype=torch.long)  # 4x3的长整型零矩阵# 从列表创建d = torch.tensor([5.5, 3])  # 直接数值初始化

2. 基于已有张量的初始化


x = torch.rand(2, 2)# 继承原有张量属性(形状/设备)new_tensor = x.new_ones(3, 3, dtype=torch.double)  # 3x3全1矩阵,继承x的设备# 正态分布(继承形状)like_tensor = torch.randn_like(x, dtype=torch.float)  # 与x同形的正态分布

二、张量属性与运算

1. 关键属性


print(x.dtype)   # 数据类型 torch.float32print(x.device)  # 存储设备 cpu/cuda:0print(x.shape)   # 等价于x.size()

2. 基本运算(加法/矩阵乘法/张量形状操作


# 加法(三种等价方式)result1 = a + bresult2 = torch.add(a, b)a.add_(b)  # in-place操作(会修改a)# 矩阵乘法mat1 = torch.randn(2, 3)mat2 = torch.randn(3, 2)product = torch.mm(mat1, mat2)  # 2x2结果矩阵# 形状操作reshaped = x.view(4)    # 展平为1D(必须连续内存)resized = x.reshape(-1) # 自动推断维度(处理非连续内存)

view(-1)自动推断维度


# 输入序列 (batch=2, seq_len=5, features=10)seq_data = torch.randn(2, 5, 10)# 转换为(batch*seq_len, features)reshaped = seq_data.view(-1, 10)  # 形状[10, 10]print(reshaped.shape)            # torch.Size([10, 10])

3. 类型转换


float_tensor = x.to(torch.float64)  # 显式转换类型gpu_tensor = x.cuda()               # 转移至GPU

三、高级切片与索引

1. 三维张量切片(面试题解析)

假设有张量 tensor = torch.randn(5, 4, 6)

• 第一个维度取第一个元素:tensor[0](等价于tensor[0, :, :]

• 第二个维度取全部元素::...

• 第三个维度取奇数索引元素:1::2(从索引1开始,步长2)

完整解:


result = tensor[0, :, 1::2]  # shape变为 (4, 3)

2. 高级索引技巧


# 布尔掩码mask = tensor > 0.5selected = tensor[mask]# 组合索引indices = torch.tensor([0, 2])partial = tensor[:, indices, :]

四、与NumPy的互操作

1. 转换机制


# Tensor -> ndarraynumpy_array = tensor.numpy()  # CPU张量直接转换# ndarray -> Tensortorch_tensor = torch.from_numpy(numpy_array)# GPU张量转换cpu_tensor = gpu_tensor.cpu()numpy_from_gpu = cpu_tensor.numpy()

2. 内存共享特性:底层其实共享一套内存


a = torch.ones(3)b = a.numpy()a.add_(1)        # 修改张量print(b)         # [2., 2., 2.] 同步变化

五、扩展知识

一、自动求导机制

  1. 核心概念

PyTorch使用动态计算图实现自动微分:

• requires_grad:标记需要跟踪梯度的张量

• 计算图:记录张量间的运算关系(正向传播)

• backward():反向传播计算梯度

• grad属性:存储梯度值(默认会累积)

  1. 基础示例

x = torch.tensor(2., requires_grad=True)y = x**2 + 3*x  # 计算图建立y.backward()     # 反向传播print(x.grad)    # 输出:tensor(7.) # 导数计算:dy/dx = 2x + 3 → 2*2 + 3 = 7
  1. 梯度累积特性
    第二次反向传播前必须清除梯度

# 第二次反向传播前必须清除梯度x.grad.zero_()    # 梯度清零y = x**3y.backward()print(x.grad)     # 3x² → 3*(2)^2 = 12
  1. 非标量梯度处理

# 多输出系统需要指定gradient参数x = torch.randn(3, requires_grad=True)y = x * 2v = torch.tensor([0.1, 1.0, 0.001], dtype=torch.float)y.backward(v)     # 加权反向传播print(x.grad)     # 输出:tensor([0.2000, 2.0000, 0.0020])
  1. 梯度控制上下文

# 禁用梯度计算(节约内存)with torch.no_grad():inference = x * 2  # 不会记录计算图print(inference.requires_grad)  # False# 临时分离张量detached_x = x.detach()  # 创建无需梯度的副本

二、张量拼接操作

  1. 维度拼接 (torch.cat)
    必须保证维度匹配,dim=0(第一个维度拼接),dim=1(第二个维度拼接)

# 在现有维度上拼接a = torch.randn(2, 3)b = torch.randn(4, 3)concat_0 = torch.cat([a, b], dim=0)  # 形状(6,3)concat_1 = torch.cat([a, a], dim=1)   # 形状(2,6)
  1. 新增维度拼接 (torch.stack)
    新增维度拼接使用torch.stack,新增第零个维度维度拼接torch.stack([c, d], dim=0),新增最后一个维度torch.stack([c.T, d.T], dim=2)

# 创建新维度c = torch.randn(3, 4)d = torch.randn(3, 4)stack_0 = torch.stack([c, d], dim=0)  # 形状(2,3,4)stack_2 = torch.stack([c.T, d.T], dim=2)  # 形状(3,4,2)
  1. 拼接规则验证

try:# 维度不匹配报错invalid = torch.cat([a, b], dim=1) except RuntimeError as e:print(f"Error: {e}")  # 非拼接维度尺寸不一致

三、广播机制详解——两个张量维度不同时,自动对齐维度(复制出来直接补充)

  1. 广播规则

当两个张量维度不同时:

  1. 从右向左对齐维度

  2. 维度相容条件:

    • 维度大小相等

    • 其中一个维度为1

  3. 自动扩展:将尺寸为1的维度复制到匹配对方

  4. 典型示例


# 案例1:向量+标量a = torch.tensor([1, 2, 3])b = torch.tensor(5)print(a + b)  # [6,7,8]# 案例2:矩阵+向量matrix = torch.ones(2, 3)    # (2,3)vector = torch.arange(3)     # (3,) → (1,3) → (2,3)print(matrix + vector)# [[0,1,2],#  [0,1,2]] + [[1,1,1],#              [1,1,1]] = [[1,2,3],#                          [1,2,3]]# 案例3:三维广播tensor_3d = torch.ones(4, 3, 2)tensor_2d = torch.tensor([[0], [1], [2]])  # (3,1) → (4,3,2)result = tensor_3d + tensor_2dprint(result.shape)  # (4,3,2)
  1. 广播失败案例

try:a = torch.ones(3, 4)b = torch.ones(2, 5)c = a + bexcept RuntimeError as e:print(f"Error: {e}")  # 无法广播

四、综合应用示例

  1. 梯度控制与广播的结合

with torch.no_grad():base = torch.ones(2, 2)delta = torch.tensor([1., 2.])  # 广播为(2,2)modified = base * deltaprint(modified)  # 无梯度跟踪# 输出:# tensor([[1., 2.],#         [1., 2.]])
  1. 拼接与自动求导

x = torch.tensor([1., 2.], requires_grad=True)y = torch.cat([x, x**2], dim=0)  # 拼接成[1,2,1,4]loss = y.sum()       # 1+2+1+4 = 8loss.backward()print(x.grad)        # [1+2x, 2+0] → [1+2*1=3, 2+0=2]# 输出:tensor([3., 2.])

http://www.xdnf.cn/news/571807.html

相关文章:

  • cookie跨域共享踩的坑
  • sqli-labs第十八关——POST-UA注入
  • 使用MATLAB输出1000以内所有完美数
  • MoManipVLA-北京邮电-2025.3.17-移动操控-未完全开源
  • UML 时序图 使用案例
  • PostGIS实现栅格数据导出PNG应用实践【ST_AsPNG 】
  • 乘“4”而上,进取不止|Aloudata 的变与不变
  • 【专四 | 2022年真题】LANGUAGE USAGE逐题总结
  • dedecms织梦全局变量调用方法总结
  • 【OCCT+ImGUI系列】009-Geom2d-Geom2d_AxisPlacement
  • 使用Jenkins部署nodejs前端项目
  • 开源Vue表单设计器FcDesigner中组件联动的配置教程
  • 中国地图上标注颜色的方法
  • 食品饮料行业AI转型趋势分析与智能化解决方案探索​
  • 实战5:个性化数字艺术生成与销售
  • 目标检测 Lite-DETR(2023)详细解读
  • 信息系统项目管理师考前练习3
  • 怎样用 esProc 生成定长时间窗口列表并统计
  • 【Java高阶面经:微服务篇】7. 1秒响应保障:超时控制如何成为高并发系统的“救火队长”?
  • esp32cmini SK6812 2个方式
  • redis--redisJava客户端:Jedis详解
  • WebFuture:在银河麒麟系统中如何无中间件为WebFuture绑定域名、SSL证书
  • 【Linux】借助gcc源码修改,搜索头文件当前进展
  • springboot链接nacos测试
  • 分类预测 | Matlab实现PNN概率神经网络多特征分类预测
  • 数学实验(Matlab绘图基础)
  • 大量程粗糙度轮廓仪适用于哪些材质和表面?
  • 矿物绝缘加热电缆行业2025数据分析报告
  • 使用Gemini, LangChain, Gradio打造一个书籍推荐系统 (第一部分)
  • Dockerfile指令详解