李沐动手学深度学习Pytorch-v2笔记【08线性回归+基础优化算法】2
文章目录
- 线性回归的简介实现
- **通过使用深度学习框架来简洁实现 线性回归模型 生成数据集**
- **使用框架的预定好的层**
- **初始化模型参数**
- **计算均方误差使用的是MESLoss类也称平方范数**
- **实例化SGD示实例**:
- **训练过程**
线性回归的简介实现
通过使用深度学习框架来简洁实现 线性回归模型 生成数据集
import torch
import numpy as np
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
def load_array(data_arrays, batch_size, is_train = True):"构造一个Pytorch数据迭代器"dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle = is_train)
batch_size = 10
data_iter = load_array((features, labels),batch_size)next(iter(data_iter))
data.TensorDataset :是 PyTorch 提供的一个类,用于将多个张量封装为一个数据集。
data_arrays :是解包操作,假设 data_arrays
是 (features, labels)
,则等价于 data.TensorDataset(features, labels)
。
DataLoader
:提供批次加载、数据打乱和多线程支持。
next(iter(data_iter))
:
iter(data_iter)
将 DataLoader
转换为迭代器。
next()
获取下一个批次(第一次调用时是第一批数据)。
使用框架的预定好的层
#"nn"是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2, 1))
初始化模型参数
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
_
:表示写入
normal
: 表示正态分布
计算均方误差使用的是MESLoss类也称平方范数
loss = nn.MSELoss()
实例化SGD示实例:
trainer = torch.optim.SGD(net.parameters(),lr = 0.03)
torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)
params
(必须参数): 这是一个包含了需要优化的参数(张量)的迭代器,例如模型的参数 model.parameters()
。
lr
(必须参数): 学习率(learning rate)。它是一个正数,控制每次参数更新的步长。较小的学习率会导致收敛较慢,较大的学习率可能导致震荡或无法收敛。
momentum
(默认值为 0
): 动量(momentum
)是一个用于加速 SGD
收敛的参数。它引入了上一步梯度的指数加权平均。通常设置在 0
到 1
之间。当 momentum
大于 0
时,算法在更新时会考虑之前的梯度,有助于加速收敛。
dampening
(默认值为 0
): 阻尼项,用于减缓动量的速度。在某些情况下,为了防止动量项引起的震荡,可以设置一个小的 dampening
值。
weight_decay
(默认值为 0
): 权重衰减,也称为 L2 正则化项。它用于控制参数的幅度,以防止过拟合。通常设置为一个小的正数。
nesterov
(默认值为 False
): Nesterov
动量。当设置为 True
时,采用 Nestero
v 动量更新规则。Nesterov
动量在梯度更新之前先进行一次预测,然后在计算梯度更新时使用这个预测。
训练过程
num_epochs = 3
for epoch in rangepochs:for X, y in data_iter:l = loss(net(X), y)trainer.zero_grad()l.backward()train_step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {1:f}')