nn.Module模块介绍
nn.Module是 PyTorch 中所有神经网络模块的基类,用于构建可训练的模型,即构建一个新结构的模型。它是 PyTorch 神经网络的核心抽象,一个抽象类,使用时必须实现必要的抽象函数。
1. 使用方法:
举一个例子:手写数字图像识别,建立一个深度学习的框架,输入是2828的图像,输出是一个1
10的向量,表示0~9的各个类别的可能性。
网络的架构:
使用nn.Module构建这个网络
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super().__init__()# 子模块定义self.conv1 = nn.Conv2d(1, 16, 3) # 输入通道1,输出通道16,卷积核3x3self.pool = nn.MaxPool2d(2) # 2x2最大池化self.fc = nn.Linear(16*13*13, 10) # 全连接层(假设输入图像为28x28)# 自定义参数(非子模块)self.scale = nn.Parameter(torch.tensor(1.0)) # 可训练标量参数,动态缩放输出结果def forward(self, x):x = self.pool(F.relu(self.conv1(x))) # 卷积 -> ReLU -> 池化x = x.view(-1, 16*13*13) # 展平x = self.fc(x) * self.scale # 全连接层 + 自定义参数缩放return x
模型在创建时,必须包含__init__和__forward__两个方法。
2. 特点
(1)参数管理
当在 nn.Module
的子类中将 nn.Parameter
或子模块(如 nn.Conv2d
)赋值给类属性时,PyTorch 会记录这些对象到内部的 _parameters
或 _modules
字典中,确保它们参与梯度计算、设备移动(CPU/GPU)、参数保存/加载等关键操作。例如,例子中成员变量 self.conv1,self.fc , self.scale。
可以实现参数的自动跟踪
model = SimpleCNN()
print(list(model.named_parameters()))
输出:
scale; conv1.weight、conv1.bias; fc_weight、fc_bias如下:
[('scale', Parameter containing:
tensor(1., requires_grad=True)), ('conv1.weight', Parameter containing:
tensor([[[[ 0.3146, -0.2337, 0.2631],[ 0.1649, 0.2865, 0.2307],[-0.0522, -0.2642, -0.1696]]],[[[ 0.0158, 0.3199, 0.0063],[ 0.0858, 0.1410, -0.0497],[-0.1104, 0.2964, 0.2612]]],[[[-0.1222, -0.1469, 0.0314],[-0.2020, -0.3159, -0.0970],[ 0.2853, 0.1428, 0.0119]]],[[[ 0.1217, -0.0545, -0.1806],[-0.0048, 0.1158, 0.1185],[-0.0908, 0.0012, -0.0098]]],[[[ 0.1017, -0.0518, 0.1661],[-0.1580, -0.0326, 0.3247],[-0.3255, -0.2731, -0.2454]]],[[[ 0.2273, -0.1849, -0.1432],[-0.3186, 0.0621, -0.2068],[ 0.0756, -0.3076, -0.2667]]],[[[ 0.2341, 0.2008, -0.0361],[-0.3005, -0.1754, -0.3298],[-0.2160, -0.3142, 0.3064]]],[[[-0.2293, -0.1122, -0.1528],[ 0.2064, 0.0754, -0.2762],[ 0.2740, -0.0463, -0.1822]]],[[[ 0.2774, 0.0322, -0.1532],[-0.0482, -0.0678, -0.2401],[-0.0318, 0.2358, -0.2187]]],[[[ 0.1396, 0.1801, 0.1789],[-0.1797, -0.1715, -0.3309],[-0.1572, 0.0549, 0.0577]]],[[[-0.3022, 0.2383, 0.1073],[-0.0813, 0.2904, -0.2532],[-0.0321, 0.0273, -0.2783]]],[[[ 0.2397, 0.3167, -0.2939],[-0.2852, -0.2542, 0.1281],[ 0.0433, 0.2920, 0.2629]]],[[[ 0.0573, -0.0992, -0.2561],[ 0.1158, 0.2102, -0.1286],[-0.3075, 0.0806, 0.2279]]],[[[-0.2582, 0.2342, -0.2332],[-0.2627, 0.2822, 0.2278],[ 0.1213, -0.1526, -0.1611]]],[[[-0.0150, 0.3245, -0.1438],[ 0.0012, 0.1359, 0.2652],[ 0.1046, 0.1012, -0.2422]]],[[[-0.0178, 0.3177, 0.1215],[ 0.0338, -0.1513, 0.2207],[ 0.1846, 0.0616, -0.0704]]]], requires_grad=True)), ('conv1.bias', Parameter containing:
tensor([-0.3028, 0.2742, 0.0908, 0.0770, 0.0357, 0.1591, 0.1625, -0.0185,0.0871, 0.2598, 0.2732, -0.0111, 0.2493, -0.1319, -0.1072, -0.0537],requires_grad=True)), ('fc.weight', Parameter containing:
tensor([[ 0.0039, -0.0049, -0.0023, ..., -0.0102, -0.0178, -0.0031],[-0.0039, -0.0058, -0.0025, ..., -0.0030, -0.0131, 0.0092],[ 0.0077, -0.0068, 0.0059, ..., 0.0078, 0.0055, 0.0096],...,[-0.0038, 0.0079, -0.0186, ..., -0.0171, -0.0047, 0.0003],[-0.0056, -0.0179, 0.0017, ..., -0.0092, -0.0189, 0.0128],[ 0.0144, -0.0057, 0.0038, ..., 0.0152, -0.0043, 0.0025]],requires_grad=True)), ('fc.bias', Parameter containing:
tensor([-0.0178, -0.0058, 0.0016, 0.0112, 0.0151, -0.0164, 0.0127, 0.0060,-0.0175, -0.0156], requires_grad=True))]进程已结束,退出代码为 0
设备移动统一管理
model.to('cuda') # 所有参数和子模块自动移至GPU
print(model.weight.is_cuda) # True
梯度计算自动启用
loss = model(x).sum()
loss.backward() # 所有注册的参数自动计算梯度
print(model.weight.grad is not None) # True
有些参数是需要手动注册的,才能实现自动的管理:参数/模块是通过列表或字典动态生成的,需要手动注册,在__init__部分进行注册。
class DynamicModel(nn.Module):def __init__(self):super().__init__()self.params_list = nn.ParameterList([nn.Parameter(torch.randn(10)) for _ in range(5)]) # 自动注册self.params_dict = nn.ModuleDict({'p1': nn.Linear(10, 5)}) # 自动注册# 普通Python容器内的参数需手动注册self.custom_list = [nn.Parameter(torch.randn(10))]for i, param in enumerate(self.custom_list):self.register_parameter(f'custom_{i}', param) # 手动注册
(2)模块嵌套
新构建的网络模型要在此处进行定义。
class ComplexModel(nn.Module):def __init__(self):super().__init__()self.conv_block = nn.Sequential(nn.Conv2d(3, 16, 3),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Linear(16*13*13, 10) # 假设输入图像为28x28def forward(self, x):x = self.conv_block(x)x = x.view(x.size(0), -1) # 展平return self.classifier(x)
(3)模型保存与加载
# 保存
torch.save(model.state_dict(), 'model.pth')# 加载
new_model = MyModel()
new_model.load_state_dict(torch.load('model.pth'))
(4)钩子(Hooks)
可进行调试或特征提取。
def forward_hook(module, input, output):print(f"Layer {module.__class__.__name__} output shape: {output.shape}")model.conv_block.register_forward_hook(forward_hook) # 注册钩子
3. 注意事项
不要直接调用 forward(),
应该用 model(x)
(PyTorch 会自动处理钩子和梯度)。模块命名唯一,子模块名称不能重复(如两个 self.fc
会覆盖)。