当前位置: 首页 > ds >正文

nn.Module模块介绍

nn.Module是 PyTorch 中所有神经网络模块的基类,用于构建可训练的模型,即构建一个新结构的模型。它是 PyTorch 神经网络的核心抽象,一个抽象类,使用时必须实现必要的抽象函数。

1. 使用方法:

举一个例子:手写数字图像识别,建立一个深度学习的框架,输入是28\times28的图像,输出是一个1\times10的向量,表示0~9的各个类别的可能性。

网络的架构:

使用nn.Module构建这个网络

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super().__init__()# 子模块定义self.conv1 = nn.Conv2d(1, 16, 3)  # 输入通道1,输出通道16,卷积核3x3self.pool = nn.MaxPool2d(2)       # 2x2最大池化self.fc = nn.Linear(16*13*13, 10) # 全连接层(假设输入图像为28x28)# 自定义参数(非子模块)self.scale = nn.Parameter(torch.tensor(1.0))  # 可训练标量参数,动态缩放输出结果def forward(self, x):x = self.pool(F.relu(self.conv1(x)))  # 卷积 -> ReLU -> 池化x = x.view(-1, 16*13*13)             # 展平x = self.fc(x) * self.scale           # 全连接层 + 自定义参数缩放return x

模型在创建时,必须包含__init__和__forward__两个方法。

2. 特点

(1)参数管理

当在 nn.Module 的子类中将 nn.Parameter 或子模块(如 nn.Conv2d)赋值给类属性时,PyTorch 会记录这些对象到内部的 _parameters 或 _modules 字典中,确保它们参与梯度计算、设备移动(CPU/GPU)、参数保存/加载等关键操作。例如,例子中成员变量 self.conv1,self.fc , self.scale。

可以实现参数的自动跟踪

model = SimpleCNN()
print(list(model.named_parameters()))

输出:

scale; conv1.weight、conv1.bias;  fc_weight、fc_bias如下:

[('scale', Parameter containing:
tensor(1., requires_grad=True)), ('conv1.weight', Parameter containing:
tensor([[[[ 0.3146, -0.2337,  0.2631],[ 0.1649,  0.2865,  0.2307],[-0.0522, -0.2642, -0.1696]]],[[[ 0.0158,  0.3199,  0.0063],[ 0.0858,  0.1410, -0.0497],[-0.1104,  0.2964,  0.2612]]],[[[-0.1222, -0.1469,  0.0314],[-0.2020, -0.3159, -0.0970],[ 0.2853,  0.1428,  0.0119]]],[[[ 0.1217, -0.0545, -0.1806],[-0.0048,  0.1158,  0.1185],[-0.0908,  0.0012, -0.0098]]],[[[ 0.1017, -0.0518,  0.1661],[-0.1580, -0.0326,  0.3247],[-0.3255, -0.2731, -0.2454]]],[[[ 0.2273, -0.1849, -0.1432],[-0.3186,  0.0621, -0.2068],[ 0.0756, -0.3076, -0.2667]]],[[[ 0.2341,  0.2008, -0.0361],[-0.3005, -0.1754, -0.3298],[-0.2160, -0.3142,  0.3064]]],[[[-0.2293, -0.1122, -0.1528],[ 0.2064,  0.0754, -0.2762],[ 0.2740, -0.0463, -0.1822]]],[[[ 0.2774,  0.0322, -0.1532],[-0.0482, -0.0678, -0.2401],[-0.0318,  0.2358, -0.2187]]],[[[ 0.1396,  0.1801,  0.1789],[-0.1797, -0.1715, -0.3309],[-0.1572,  0.0549,  0.0577]]],[[[-0.3022,  0.2383,  0.1073],[-0.0813,  0.2904, -0.2532],[-0.0321,  0.0273, -0.2783]]],[[[ 0.2397,  0.3167, -0.2939],[-0.2852, -0.2542,  0.1281],[ 0.0433,  0.2920,  0.2629]]],[[[ 0.0573, -0.0992, -0.2561],[ 0.1158,  0.2102, -0.1286],[-0.3075,  0.0806,  0.2279]]],[[[-0.2582,  0.2342, -0.2332],[-0.2627,  0.2822,  0.2278],[ 0.1213, -0.1526, -0.1611]]],[[[-0.0150,  0.3245, -0.1438],[ 0.0012,  0.1359,  0.2652],[ 0.1046,  0.1012, -0.2422]]],[[[-0.0178,  0.3177,  0.1215],[ 0.0338, -0.1513,  0.2207],[ 0.1846,  0.0616, -0.0704]]]], requires_grad=True)), ('conv1.bias', Parameter containing:
tensor([-0.3028,  0.2742,  0.0908,  0.0770,  0.0357,  0.1591,  0.1625, -0.0185,0.0871,  0.2598,  0.2732, -0.0111,  0.2493, -0.1319, -0.1072, -0.0537],requires_grad=True)), ('fc.weight', Parameter containing:
tensor([[ 0.0039, -0.0049, -0.0023,  ..., -0.0102, -0.0178, -0.0031],[-0.0039, -0.0058, -0.0025,  ..., -0.0030, -0.0131,  0.0092],[ 0.0077, -0.0068,  0.0059,  ...,  0.0078,  0.0055,  0.0096],...,[-0.0038,  0.0079, -0.0186,  ..., -0.0171, -0.0047,  0.0003],[-0.0056, -0.0179,  0.0017,  ..., -0.0092, -0.0189,  0.0128],[ 0.0144, -0.0057,  0.0038,  ...,  0.0152, -0.0043,  0.0025]],requires_grad=True)), ('fc.bias', Parameter containing:
tensor([-0.0178, -0.0058,  0.0016,  0.0112,  0.0151, -0.0164,  0.0127,  0.0060,-0.0175, -0.0156], requires_grad=True))]进程已结束,退出代码为 0

 设备移动统一管理

model.to('cuda')  # 所有参数和子模块自动移至GPU
print(model.weight.is_cuda)  # True

梯度计算自动启用

loss = model(x).sum()
loss.backward()  # 所有注册的参数自动计算梯度
print(model.weight.grad is not None)  # True

有些参数是需要手动注册的,才能实现自动的管理:参数/模块是通过列表或字典动态生成的,需要手动注册,在__init__部分进行注册。

class DynamicModel(nn.Module):def __init__(self):super().__init__()self.params_list = nn.ParameterList([nn.Parameter(torch.randn(10)) for _ in range(5)])  # 自动注册self.params_dict = nn.ModuleDict({'p1': nn.Linear(10, 5)})                              # 自动注册# 普通Python容器内的参数需手动注册self.custom_list = [nn.Parameter(torch.randn(10))]for i, param in enumerate(self.custom_list):self.register_parameter(f'custom_{i}', param)  # 手动注册

(2)模块嵌套

新构建的网络模型要在此处进行定义。

class ComplexModel(nn.Module):def __init__(self):super().__init__()self.conv_block = nn.Sequential(nn.Conv2d(3, 16, 3),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Linear(16*13*13, 10)  # 假设输入图像为28x28def forward(self, x):x = self.conv_block(x)x = x.view(x.size(0), -1)  # 展平return self.classifier(x)

(3)模型保存与加载

# 保存
torch.save(model.state_dict(), 'model.pth')# 加载
new_model = MyModel()
new_model.load_state_dict(torch.load('model.pth'))

(4)钩子(Hooks)

可进行调试或特征提取。

def forward_hook(module, input, output):print(f"Layer {module.__class__.__name__} output shape: {output.shape}")model.conv_block.register_forward_hook(forward_hook)  # 注册钩子

3. 注意事项

不要直接调用 forward()应该用 model(x)(PyTorch 会自动处理钩子和梯度)。模块命名唯一,子模块名称不能重复(如两个 self.fc 会覆盖)。

http://www.xdnf.cn/news/18242.html

相关文章:

  • USB 2.0声卡
  • 考研复习-操作系统-第一章-计算机系统概述
  • k8s-单主机Master集群部署+单个pod部署lnmp论坛服务(小白的“升级打怪”成长之路)
  • 什么是GD库?PHP中7大类64个GD库函数用法详解
  • 【撸靶笔记】第五关:GET - Double Injection - Single Quotes - String
  • Qt——主窗口 mainWindow
  • GaussDB常用术语缩写及释义
  • 【Golang】:错误处理
  • AI Search进化论:从RAG到DeepSearch的智能体演变全过程
  • 第12章《学以致用》—PowerShell 自学闭环与实战笔记
  • 第七十七章:多模态推理与生成——开启AI“从无到有”的时代!
  • 计算机程序编程软件开发设计之node..js语言开发的基于Vue框架的选课管理系统的设计与实现、基于express框架的在线选课系统的设计与实现
  • Jenkins - CICD 注入环境变量避免明文密码暴露
  • Python中f - 字符串(f-string)
  • Hadoop入门
  • 前端基础知识版本控制系列 - 05( Git 中 HEAD、工作树和索引之间的区别)
  • 图论水题4
  • 写作路上的迷茫与突破
  • java_spring boot 中使用 log4j2 及 自定义layout设置示例
  • NestJS 手动集成TypeORM
  • 关于第一次接触Linux TCP/IP网络相关项目
  • Docker入门:容器化技术的第一堂课
  • python---装饰器
  • 在线编程题目之小试牛刀
  • [每周一更]-(第155期):Go 1.25 发布:新特性、技术思考与 Go vs Rust 竞争格局分析
  • 回溯剪枝的 “减法艺术”:化解超时危机的 “救命稻草”(一)
  • 机器学习算法篇(十三)------词向量转化的算法思想详解与基于词向量转换的文本数据处理的好评差评分类实战(NPL基础实战)
  • 微服务之间的调用需要走网关么?
  • Linux Shell定时检查日期执行Python脚本
  • Python数据类型转换详解:从基础到实践