当前位置: 首页 > news >正文

python学习打卡day33

DAY 33 简单的神经网络

知识点回顾:

  1. PyTorch和cuda的安装
  2. 查看显卡信息的命令行命令(cmd中使用)
  3. cuda的检查
  4. 简单神经网络的流程
    1. 数据预处理(归一化、转换成张量)
    2. 模型的定义
      1. 继承nn.Module类
      2. 定义每一个层
      3. 定义前向传播流程
    3. 定义损失函数和优化器
    4. 定义训练流程
    5. 可视化loss过程

预处理补充:

注意事项:

1. 分类任务中,若标签是整数(如 0/1/2 类别),需转为long类型(对应 PyTorch 的torch.long),否则交叉熵损失函数会报错。

2. 回归任务中,标签需转为float类型(如torch.float32)。

作业:今日的代码,要做到能够手敲。这已经是最简单最基础的版本了。

import pandas as pd 
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
iris=load_iris()
X=iris.data
y=iris.target
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=42)
scaler=MinMaxScaler()
X_train=scaler.fit_transform(X_train)
X_test=scaler.transform(X_test)
X_train=torch.FloatTensor(X_train)
X_test=torch.FloatTensor(X_test)
y_train=torch.LongTensor(y_train)
y_test=torch.LongTensor(y_test)class MLP(nn.Module):def __init__(self):super(MLP,self).__init__()self.fc1=nn.Linear(4,10)self.relu=nn.ReLU()self.fc2=nn.Linear(10,3)def forward(self,x):out=self.fc1(x)out=self.relu(out)out=self.fc2(out)return out
model=MLP()criterion=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=0.01)num_epochs=20000
losses=[]
for epoch in range(num_epochs):outputs=model.forward(X_train)loss=criterion(outputs,y_train)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())if(epoch+1)%100==0:print(f'Epoch[{epoch+1}/{num_epochs}],loss:{loss.item():.4f}')plt.plot(range(num_epochs), losses)#()内对应的是X轴和Y轴的数据
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

 

使用Adam优化器不到2000轮就收敛了,且损失要比SGD小

@浙大疏锦行

http://www.xdnf.cn/news/597709.html

相关文章:

  • 等离子体隐身技术和小型等离子体防御装置设计
  • 军事目标系列之迷彩作战人员检测数据集VOC+YOLO格式2755张1类别
  • C#中WSDL文件引用问题
  • 【接近平均分配箱子数量】2022-1-23
  • uni 常用api
  • 学习STC51单片机11(芯片为STC89C52RC)
  • 嵌入式软件架构规范之 - 分层设计
  • Linux终端输入有80个字符的限制处理
  • 【com.unity3d.player.UnityPlayer介绍】
  • Spring IoC 和 AOP -- 核心原理与高频面试题解析
  • 单测覆盖率和通过率的稳定性问题,以及POM文件依赖包版本一致性的挑战
  • 位运算及其算法
  • 解决wsl没代理的问题
  • 第4周_作业题_逐步构建你的深度神经网络
  • 论文解读 | 《药用真菌桑黄通过内质网应激 - 线粒体损伤诱导人宫颈癌细胞凋亡》
  • 从JDK 17到JDK 21:Java核心特性概述
  • Python之web错误处理与异常捕获
  • 【人工智能】从零到一:大模型应用开发的奇幻旅程
  • 【修改提问代码-筹款】2022-1-29
  • Qwen2.5-VL技术解读和文档解析可行性验证
  • Any类(C++17类型擦除,也称上帝类)
  • ORA-00313 ORA-00312 ORA-27037 redo被删除后重建
  • 如何顺利地将应用程序从 Android 转移到Android
  • SpringCloud (3) 配置中心
  • vue项目的dist在nginx部署后报错Uncaught SyntaxError
  • 技术篇-2.2.JAVA应用场景及开发工具安装
  • Spring Boot 注解 @ConditionalOnMissingBean是什么
  • 嵌入式开发学习日志(linux系统编程--io文件偏移函数(3)和目录)Day26
  • 文件IO操作、目录操作
  • 【leetcode】3355. 零数组变换Ⅰ