当前位置: 首页 > java >正文

PyTorch中“原地”赋值的思考

在开发一个PyTorch模块时,遇到了一个诡异的现象,将他描述出来就是下面这样:

f[..., :p_index - 1] = f[..., 1:p_index]

这个操作将f张量的部分数值进行左移,我在模型训练的时候还能正常跑,但是当我将模型部署到项目中时,这行代码报错了!

Traceback (most recent call last):File "<input>", line 1, in <module>
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

这个PyTorch报错是因为在执行操作时,输入张量和目标张量共享了同一块内存地址(存在内存重叠),导致PyTorch无法安全地完成原地(in-place)操作。

既然这样的话为什么在模型训练的时候不会这样呢?后面我仔细研究了一下午,发现了下面的原因:


当我们模型在训练阶段中,f的形状通常是(B,F)的形式存在的,而在部署的时候,作推理时数据通常是(1,F)的形式,所以会出现下面的情况:

# 创建高维张量(3维)
f_3d = torch.randn(16, 1, 25)
slice_3d = f_3d[..., 1:24]  # 源切片print("高维张量切片是否连续:")
print(slice_3d.is_contiguous())  # 输出 False# 创建一维张量对比
f_1d = torch.randn(1, 1, 25)
slice_1d = f_1d[..., 1:24]print("\n一维张量切片是否连续:")
print(slice_1d.is_contiguous())  # 输出 True

可以看到,当张量是维度大于1时,其在内存中是非连续存储的,而张量维度为1时,其在内存中是连续存储的。对于非连续张量,PyTorch会在赋值时隐式创建临时副本,避免内存覆盖。因此在进行原地赋值时不会报错。

最后,为了加强代码的鲁棒性,我在所有涉及这部分操作的代码后面加上了clone()函数。

f[..., :p_index - 1] = f[..., 1:p_index].clone()
http://www.xdnf.cn/news/3425.html

相关文章:

  • QT —— 信号和槽(带参数的信号和槽函数)
  • Qwen3 正式发布
  • Ethan独立开发产品日报 | 2025-04-30
  • Java中修饰类的关键字
  • [蓝桥杯 2021 省 AB] 砝码称重 Java
  • 【论文速递】2025年08周 (Robotics/Embodied AI/LLM)
  • Y1代码AC集
  • 坚鹏:平安保险集团《保险行业发展趋势与AI应用方法及案例》培训
  • 【Redis】Another Redis Desktop Manager 安装指南
  • 深入理解虚拟机与容器:原理、对比与应用场景分析
  • 动态规划简单题2
  • 算法-堆、排序算法、矩阵乘法
  • 面试手撕——迭代法中序遍历二叉树
  • 负载均衡深度实践:基于Nginx+Keepalived的高可用方案与Zabbix监控设计
  • Cesium Entity动态更新
  • 嵌入式AI还是一片蓝海
  • Day107 | 147.对链表进行插入排序 | 简单选择、冒泡、直接插入
  • 【专题五】位运算(2)
  • AXI中的out of order和interleaving的定义和两者的差别?
  • OSPF的路由
  • Go-web开发之社区功能
  • Java 中那些奇怪的空指针报错场景及解决方案NullPointerException
  • 【计算机视觉】语义分割:MMSegmentation:OpenMMLab开源语义分割框架实战指南
  • MySQL数据同步之Canal讲解
  • 2025年- H16-Lc124-169.多数元素(技巧)---java版
  • 7.0/Q1,GBD数据库最新文章解读
  • ClackyAI:下一代智能云开发环境的技术革新与实践价值
  • WPF使用依赖注入框架AutoMapper
  • 仿腾讯会议——服务器结构讲解
  • Matlab/Simulink - BLDC直流无刷电机仿真基础教程(四) - PWM调制模拟