当前位置: 首页 > news >正文

python训练营第33天

MLP神经网络的训练

知识点回顾:

  1. PyTorch和cuda的安装
  2. 查看显卡信息的命令行命令(cmd中使用)
  3. cuda的检查
  4. 简单神经网络的流程
    1. 数据预处理(归一化、转换成张量)
    2. 模型的定义
      1. 继承nn.Module类
      2. 定义每一个层
      3. 定义前向传播流程
    3. 定义损失函数和优化器
    4. 定义训练流程
    5. 可视化loss过程
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import matplotlib.pyplot as plt# 检查CUDA可用性并输出相关信息
if torch.cuda.is_available():print("CUDA可用!")device_count = torch.cuda.device_count()print(f"可用的CUDA设备数量:{device_count}")current_device = torch.cuda.current_device()print(f"当前使用的CUDA设备索引:{current_device}")device_name = torch.cuda.get_device_name(current_device)print(f"当前CUDA设备的名称:{device_name}")cuda_version = torch.version.cudaprint(f"CUDA版本:{cuda_version}")
else:print("CUDA不可用。")# 加载并准备Iris数据集
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)# 数据标准化处理
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 转换为PyTorch张量
X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)# 定义多层感知机模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 初始化模型、损失函数和优化器
model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 添加了学习率参数# 训练模型
num_epochs = 20000
losses = []for epoch in range(num_epochs):outputs = model.forward(X_train)loss = criterion(outputs, y_train)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 可视化训练损失
plt.plot(range(num_epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

输出结果:

(120, 4)
(120,)  
(30, 4) 
(30,)   
Epoch [100/20000], Loss: 1.0980
Epoch [200/20000], Loss: 1.0692
Epoch [300/20000], Loss: 1.0398
Epoch [400/20000], Loss: 1.0060
Epoch [500/20000], Loss: 0.9665
Epoch [600/20000], Loss: 0.9214
Epoch [700/20000], Loss: 0.8731
Epoch [800/20000], Loss: 0.8241
Epoch [900/20000], Loss: 0.7777
Epoch [1000/20000], Loss: 0.7360
Epoch [1100/20000], Loss: 0.6991
Epoch [1200/20000], Loss: 0.6666
Epoch [1300/20000], Loss: 0.6380
Epoch [1400/20000], Loss: 0.6127
Epoch [1500/20000], Loss: 0.5903
Epoch [1600/20000], Loss: 0.5703
Epoch [1700/20000], Loss: 0.5523
Epoch [1800/20000], Loss: 0.5360
Epoch [1900/20000], Loss: 0.5211
Epoch [2000/20000], Loss: 0.5074
Epoch [2100/20000], Loss: 0.4948
Epoch [2200/20000], Loss: 0.4830
Epoch [2300/20000], Loss: 0.4720
Epoch [2400/20000], Loss: 0.4616
Epoch [2500/20000], Loss: 0.4518
Epoch [2600/20000], Loss: 0.4425
Epoch [2700/20000], Loss: 0.4336
Epoch [2800/20000], Loss: 0.4251
Epoch [2900/20000], Loss: 0.4168
Epoch [3000/20000], Loss: 0.4088
Epoch [3100/20000], Loss: 0.4010
Epoch [3200/20000], Loss: 0.3935
Epoch [3300/20000], Loss: 0.3860
Epoch [3400/20000], Loss: 0.3787
Epoch [3500/20000], Loss: 0.3716
Epoch [3600/20000], Loss: 0.3645
Epoch [3700/20000], Loss: 0.3576
Epoch [3800/20000], Loss: 0.3508
Epoch [3900/20000], Loss: 0.3440
Epoch [4000/20000], Loss: 0.3374
Epoch [4100/20000], Loss: 0.3308
Epoch [4200/20000], Loss: 0.3243
Epoch [4300/20000], Loss: 0.3179
Epoch [4400/20000], Loss: 0.3115
Epoch [4500/20000], Loss: 0.3053
Epoch [4600/20000], Loss: 0.2991
Epoch [4700/20000], Loss: 0.2930
Epoch [4800/20000], Loss: 0.2870
Epoch [4900/20000], Loss: 0.2810
Epoch [5000/20000], Loss: 0.2752
Epoch [5100/20000], Loss: 0.2694
Epoch [5200/20000], Loss: 0.2638
Epoch [5300/20000], Loss: 0.2583
Epoch [5400/20000], Loss: 0.2530
Epoch [5500/20000], Loss: 0.2478
Epoch [5600/20000], Loss: 0.2428
Epoch [5700/20000], Loss: 0.2378
Epoch [5800/20000], Loss: 0.2330
Epoch [5900/20000], Loss: 0.2284
Epoch [6000/20000], Loss: 0.2238
Epoch [6100/20000], Loss: 0.2193
Epoch [6200/20000], Loss: 0.2150
Epoch [6300/20000], Loss: 0.2108
Epoch [6400/20000], Loss: 0.2067
Epoch [6500/20000], Loss: 0.2027
Epoch [6600/20000], Loss: 0.1989
Epoch [6700/20000], Loss: 0.1951
Epoch [6800/20000], Loss: 0.1914
Epoch [6900/20000], Loss: 0.1878
Epoch [7000/20000], Loss: 0.1844
Epoch [7100/20000], Loss: 0.1810
Epoch [7200/20000], Loss: 0.1778
Epoch [7300/20000], Loss: 0.1746
Epoch [7400/20000], Loss: 0.1716
Epoch [7500/20000], Loss: 0.1686
Epoch [7600/20000], Loss: 0.1658
Epoch [7700/20000], Loss: 0.1630
Epoch [7800/20000], Loss: 0.1604
Epoch [7900/20000], Loss: 0.1578
Epoch [8000/20000], Loss: 0.1553
Epoch [8100/20000], Loss: 0.1528
Epoch [8200/20000], Loss: 0.1505
Epoch [8300/20000], Loss: 0.1483
Epoch [8400/20000], Loss: 0.1461
Epoch [8500/20000], Loss: 0.1440
Epoch [8600/20000], Loss: 0.1420
Epoch [8700/20000], Loss: 0.1400
Epoch [8800/20000], Loss: 0.1381
Epoch [8900/20000], Loss: 0.1363
Epoch [9000/20000], Loss: 0.1345
Epoch [9100/20000], Loss: 0.1328
Epoch [9200/20000], Loss: 0.1312
Epoch [9300/20000], Loss: 0.1295
Epoch [9400/20000], Loss: 0.1280
Epoch [9500/20000], Loss: 0.1265
Epoch [9600/20000], Loss: 0.1250
Epoch [9700/20000], Loss: 0.1236
Epoch [9800/20000], Loss: 0.1222
Epoch [9900/20000], Loss: 0.1208
Epoch [10000/20000], Loss: 0.1195
Epoch [10100/20000], Loss: 0.1183
Epoch [10200/20000], Loss: 0.1170
Epoch [10300/20000], Loss: 0.1159
Epoch [10400/20000], Loss: 0.1147
Epoch [10500/20000], Loss: 0.1136
Epoch [10600/20000], Loss: 0.1125
Epoch [10700/20000], Loss: 0.1114
Epoch [10800/20000], Loss: 0.1104
Epoch [10900/20000], Loss: 0.1094
Epoch [11000/20000], Loss: 0.1084
Epoch [11100/20000], Loss: 0.1075
Epoch [11200/20000], Loss: 0.1065
Epoch [11300/20000], Loss: 0.1056
Epoch [11400/20000], Loss: 0.1048
Epoch [11500/20000], Loss: 0.1039
Epoch [11600/20000], Loss: 0.1031
Epoch [11700/20000], Loss: 0.1023
Epoch [11800/20000], Loss: 0.1015
Epoch [11900/20000], Loss: 0.1007
Epoch [12000/20000], Loss: 0.0999
Epoch [12100/20000], Loss: 0.0992
Epoch [12200/20000], Loss: 0.0985
Epoch [12300/20000], Loss: 0.0978
Epoch [12400/20000], Loss: 0.0971
Epoch [12500/20000], Loss: 0.0964
Epoch [12600/20000], Loss: 0.0958
Epoch [12700/20000], Loss: 0.0951
Epoch [12800/20000], Loss: 0.0945
Epoch [12900/20000], Loss: 0.0939
Epoch [13000/20000], Loss: 0.0933
Epoch [13100/20000], Loss: 0.0927
Epoch [13200/20000], Loss: 0.0922
Epoch [13300/20000], Loss: 0.0916
Epoch [13400/20000], Loss: 0.0910
Epoch [13500/20000], Loss: 0.0905
Epoch [13600/20000], Loss: 0.0900
Epoch [13700/20000], Loss: 0.0895
Epoch [13800/20000], Loss: 0.0890
Epoch [13900/20000], Loss: 0.0885
Epoch [14000/20000], Loss: 0.0880
Epoch [14100/20000], Loss: 0.0875
Epoch [14200/20000], Loss: 0.0871
Epoch [14300/20000], Loss: 0.0866
Epoch [14400/20000], Loss: 0.0862
Epoch [14500/20000], Loss: 0.0857
Epoch [14600/20000], Loss: 0.0853
Epoch [14700/20000], Loss: 0.0849
Epoch [14800/20000], Loss: 0.0845
Epoch [14900/20000], Loss: 0.0840
Epoch [15000/20000], Loss: 0.0837
Epoch [15100/20000], Loss: 0.0833
Epoch [15200/20000], Loss: 0.0829
Epoch [15300/20000], Loss: 0.0825
Epoch [15400/20000], Loss: 0.0821
Epoch [15500/20000], Loss: 0.0818
Epoch [15600/20000], Loss: 0.0814
Epoch [15700/20000], Loss: 0.0811
Epoch [15800/20000], Loss: 0.0807
Epoch [15900/20000], Loss: 0.0804
Epoch [16000/20000], Loss: 0.0800
Epoch [16100/20000], Loss: 0.0797
Epoch [16200/20000], Loss: 0.0794
Epoch [16300/20000], Loss: 0.0791
Epoch [16400/20000], Loss: 0.0788
Epoch [16500/20000], Loss: 0.0785
Epoch [16600/20000], Loss: 0.0782
Epoch [16700/20000], Loss: 0.0779
Epoch [16800/20000], Loss: 0.0776
Epoch [16900/20000], Loss: 0.0773
Epoch [17000/20000], Loss: 0.0770
Epoch [17100/20000], Loss: 0.0767
Epoch [17200/20000], Loss: 0.0765
Epoch [17300/20000], Loss: 0.0762
Epoch [17400/20000], Loss: 0.0759
Epoch [17500/20000], Loss: 0.0757
Epoch [17600/20000], Loss: 0.0754
Epoch [17700/20000], Loss: 0.0751
Epoch [17800/20000], Loss: 0.0749
Epoch [17900/20000], Loss: 0.0747
Epoch [18000/20000], Loss: 0.0744
Epoch [18100/20000], Loss: 0.0742
Epoch [18200/20000], Loss: 0.0739
Epoch [18300/20000], Loss: 0.0737
Epoch [18400/20000], Loss: 0.0735
Epoch [18500/20000], Loss: 0.0733
Epoch [18600/20000], Loss: 0.0730
Epoch [18700/20000], Loss: 0.0728
Epoch [18800/20000], Loss: 0.0726
Epoch [18900/20000], Loss: 0.0724
Epoch [19000/20000], Loss: 0.0722
Epoch [19100/20000], Loss: 0.0720
Epoch [19200/20000], Loss: 0.0718
Epoch [19300/20000], Loss: 0.0716
Epoch [19400/20000], Loss: 0.0714
Epoch [19500/20000], Loss: 0.0712
Epoch [19600/20000], Loss: 0.0710
Epoch [19700/20000], Loss: 0.0708
Epoch [19800/20000], Loss: 0.0706
Epoch [19900/20000], Loss: 0.0704
Epoch [20000/20000], Loss: 0.0702
PS E:\shucai\py>
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
PS E:\shucai\py>
Epoch [20000/20000], Loss: 0.0702
PS E:\shucai\py>
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702
Epoch [20000/20000], Loss: 0.0702

@浙大疏锦行

http://www.xdnf.cn/news/629119.html

相关文章:

  • Lesson 22 A glass envelope
  • HJ14 字符串排序【牛客网】
  • Spring AI 源码解析:Tool Calling链路调用流程及示例
  • 从法律视角看债务管理:湖北理元理律师事务所的实践探索
  • 【信息系统项目管理师】一文掌握高项常考题型-成本类计算
  • 巡礼中国西极·跨越昆仑天山 | 北斗卫星徽章护航昆仑科考
  • 神经算子项目实战:数据分析、可视化与实现全过程
  • 归一化 超全总结!!
  • leetcode hot100刷题日记——16.全排列
  • 探秘Transformer系列之(34)--- 量化基础
  • 开源轻量级语音合成和语音克隆模型:OuteTTS-1.0-0.6B
  • AWTK嵌入式图形框架开发备忘(二)
  • 【GESP真题解析】第 5 集 GESP 二级 2023 年 3 月编程题 2:百鸡问题
  • 【Python】【电网规划】基于经济与可靠性双目标的混合配电系统规划及可靠性评估
  • ShenNiusModularity项目源码学习(30:ShenNius.Admin.Mvc项目分析-15)
  • 可增添功能的鼠标右键优化工具
  • 【PINN】DeepXDE学习训练营(33)——pinn_forward-fractional_Poisson_1d.py
  • C++:共享指针unique_ptr的理解与应用
  • 每日定投40刀BTC(17)20250511 - 20250524
  • 什么是数据分析
  • Go基础语法与控制结构
  • ROS云课三分钟-破壁篇GCompris-一小部分支持Edu应用列表-2025
  • 部署n8n
  • 海思SVP_NPU开发适配
  • Python训练营---Day35
  • 哈希表原理与双散列实战指南
  • 超时处理机制设计:从TICK到回调
  • 刷leetcode hot100返航版--贪心5/23
  • Python性能优化利器:__slots__的深度解析与避坑指南
  • 《2.1.4 C语言中的整数类型及类型转换|精讲篇》