当前位置: 首页 > news >正文

Day33 MLP神经网络的训练

@浙大疏锦行

知识点回顾:
  1. PyTorch和cuda的安装
  2. 查看显卡信息的命令行命令(cmd中使用)
  3. cuda的检查
  4. 简单神经网络的流程
    1. 数据预处理(归一化、转换成张量)
    2. 模型的定义
      1. 继承nn.Module类
      2. 定义每一个层
      3. 定义前向传播流程
    3. 定义损失函数和优化器
    4. 定义训练流程
    5. 可视化loss过程

前置知识回顾:

  1. 梯度下降的思想:梯度下降是一种优化算法,用于最小化损失函数。其核心思想是通过计算损失函数的梯度(即函数在各点上的偏导数向量),沿着梯度的反方向调整网络参数,从而逐步接近损失函数的最小值。梯度指向函数值增长最快的方向,因此负梯度方向就是函数值下降最快的方向。
  2. 激活函数的作用:激活函数是神经网络中的重要组件,用于引入非线性因素,使神经网络能够学习和表示复杂的非线性关系。如果没有激活函数,多层神经网络就退化为线性模型,无法处理复杂任务。常见的激活函数包括 ReLU、Sigmoid 和 Tanh 等。
  3. 损失函数的作用:损失函数用于衡量模型预测值与真实值之间的差异。它是训练过程中需要最小化的函数,指导模型如何调整参数以提高预测准确性。损失函数的选择取决于具体任务,如分类问题常用交叉熵损失,回归问题常用均方误差损失。
  4. 优化器:优化器是训练神经网络的关键组件,负责根据损失函数的梯度信息更新模型参数(权重和偏置),以最小化损失函数。不同的优化器有不同的更新策略,影响模型的收敛速度和最终性能。常见的优化器包括 SGD、Adam、RMSProp 等。
  5. 神经网络的概念:神经网络是一种模拟生物神经网络结构的计算模型,由大量相互连接的节点(神经元)组成。每个神经元接收输入,进行加权求和并通过激活函数产生输出。神经网络通过多层结构能够学习复杂的非线性关系,用于解决分类、回归等各种机器学习任务。

神经网络由于内部比较灵活,所以封装的比较浅,可以对模型做非常多的改进,而不像机器学习三行代码固定。与机器学习中最大的不同。

数据准备

# 仍然用4特征,3分类的鸢尾花数据集作为我们今天的数据集
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征数据
y = iris.target  # 标签数据
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 打印下尺寸
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)# 归一化数据,神经网络对于输入数据的尺寸敏感,归一化是最常见的处理方式
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test) #确保训练集和测试集是相同的缩放# 将数据转换为 PyTorch 张量,因为 PyTorch 使用张量进行训练
# y_train和y_test是整数,所以需要转化为long类型,如果是float32,会输出1.0 0.0
X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)

模型架构定义

定义一个简单的全连接神经网络模型,包含一个输入层、一个隐藏层和一个输出层。

定义层数+定义前向传播顺序

import torch
import torch.nn as nn
import torch.optim as optimclass MLP(nn.Module): # 定义一个多层感知机(MLP)模型,继承父类nn.Moduledef __init__(self): # 初始化函数super(MLP, self).__init__() # 调用父类的初始化函数# 前三行是八股文,后面的是自定义的self.fc1 = nn.Linear(4, 10)  # 输入层到隐藏层self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)  # 隐藏层到输出层
# 输出层不需要激活函数,因为后面会用到交叉熵函数cross_entropy,交叉熵函数内部有softmax函数,会把输出转化为概率def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 实例化模型
model = MLP()

模型训练,定义损失函数和优化器

criterion = nn.CrossEntropyLoss()# 使用随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# 使用自适应学习率的化器
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
num_epochs = 20000 # 训练的轮数# 用于存储每个 epoch 的损失值
losses = []for epoch in range(num_epochs): # range是从0开始,所以epoch是从0开始# 前向传播outputs = model.forward(X_train)   # 显式调用forward函数# outputs = model(X_train)  # 常见写法隐式调用forward函数,其实是用了model类的__call__方法loss = criterion(outputs, y_train) # output是模型预测值,y_train是真实标签# 反向传播和优化optimizer.zero_grad() #梯度清零,因为PyTorch会累积梯度,所以每次迭代需要清零,梯度累计是那种小的bitchsize模拟大的bitchsizeloss.backward() # 反向传播计算梯度optimizer.step() # 更新参数# 记录损失值losses.append(loss.item())# 打印训练信息if (epoch + 1) % 100 == 0: # range是从0开始,所以epoch+1是从当前epoch开始,每100个epoch打印一次print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

训练时,要关注每个epoch训练完后测试集的表现:测试集的loss和准确度

最后进行结果可视化

import matplotlib.pyplot as plt
# 可视化损失曲线
plt.plot(range(num_epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

http://www.xdnf.cn/news/1352719.html

相关文章:

  • 「ECG信号处理——(24)基于ECG和EEG信号的多模态融合疲劳分析」2025年8月23日
  • 前端 H5分片上传 vue实现大文件
  • 【卫星通信】超低码率语音编码ULBC:EnCodec神经音频编解码器架构深度解析
  • piclist+gitee操作指南
  • 【Day 11】238.除自身以外数组的乘积
  • Transformer核心概念I-token
  • SpringBoot 快速上手:从环境搭建到 HelloWorld 实战
  • Excel 条件高亮工具,秒高亮显示符合筛选条件的行数据
  • 「数据获取」《中国能源统计年鉴》(1986-2023)(获取方式看绑定的资源)
  • 蓝桥杯算法之基础知识(2)——Python赛道
  • 【51单片机学习】直流电机驱动(PWM)、AD/DA、红外遥控(外部中断)
  • mmdetection:记录算法训练配置文件
  • A Large Scale Synthetic Graph Dataset Generation Framework的学习笔记
  • Mysql EXPLAIN详解:从底层原理到性能优化实战
  • 如何在Ubuntu中删除或修改已有的IP地址设置?
  • C语言---数据类型
  • PyTorch生成式人工智能——VQ-VAE详解与实现
  • TypeScript 的泛型(Generics)作用理解
  • Kafka 概念与概述
  • 在TencentOS3上部署OpenTenBase:从入门到实战的完整指南
  • 【Java学习笔记】18.反射与注解的应用
  • 遥感机器学习入门实战教程|Sklearn案例⑧:评估指标(metrics)全解析
  • tcpdump命令打印抓包信息
  • 【golang】ORM框架操作数据库
  • 2-5.Python 编码基础 - 键盘输入
  • STM32CubeIDE V1.9.0下载资源链接
  • 醋酸镨:催化剂领域的璀璨新星
  • LangChain4J-基础(整合Spring、RAG、MCP、向量数据库、提示词、流式输出)
  • 信贷模型域——信贷获客模型(获客模型)
  • 温度对直线导轨的性能有哪些影响?