当前位置: 首页 > java >正文

《Pytorch深度学习实践》ch5-Logistic回归

                                                        ------B站《刘二大人》

1.Classification

  • 经典的分类数据集:MNIST(0 - 9)

  • 导入数据集:(路径,训练集/测试集,是否下载)
import torchvision
train_set = torchvision.datasets.MINIST(root='../dataset/mnist', train=True,  download=True)
test_set  = torchvision.datasets.MINIST(root='../dataset/mnist', train=False, download=True)

2.Sigmoid functions

  • 由于分类问题就是求概率的最大值,所以利用 S 函数将数值全部映射到 [0,1] 区间;
  • 最著名的就是这个 Logistic 函数:

  • 其它的 S 函数:

3.Logistic Regression Model

  • 就是在原函数基础上加一个 Sigmoid:

4.Loss and BCE

  • BCE:Binary Cross Entropy,二元交叉熵损失:

5.Implemetation

  • 导包:
import torch
import torch.nn.functional as F
  • 数据集:y 变为 {0,1},二分类
# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
  • 模型:F.sigmoid()函数
# 模型
class LogisticRegressionModel(torch.nn.Module): # Module 构建计算图def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 前馈y_pred = F.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel() # 实例化
  •  损失和优化器:BCELoss
# 损失函数和优化器
criterion = torch.nn.BCELoss(reduction = 'sum') # 计算损失,参数为(y_pred, y)optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # 进行更新
  • 训练:
# 训练
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) # 1.前馈print(epoch, loss)optimizer.zero_grad() # 梯度清零loss.backward() # 2.反馈optimizer.step() # 3.更新

6.Result

import numpy as np
import matplotlib.pyplot as pltx = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200,1)) # 将x数组转换为PyTorch张量,并将其形状调整为列向量(200x1)
y_t = model(x_t)
y = y_t.data.numpy() # 将输出张量y_t转换为NumPy数组yplt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') # 绘制一条从x=0到x=10的红色水平线,y值为0.5
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
  • 绘图如下:

http://www.xdnf.cn/news/10837.html

相关文章:

  • 百万级临床试验数据库TrialPanorama发布!AI助力新药研发与临床评价迎来新基石
  • Rhino插件大全下载指南:解锁犀牛潜能,提升设计效率
  • C++11:unique_ptr的基本用法、使用场景和最佳使用指南
  • 利用lightgbm预测adult数据集
  • 支持TypeScript并打包为ESM/CommonJS/UMD三种格式的脚手架项目
  • MYSQL索引详解及索引优化、分析
  • Cyber Weekly #58
  • 低成本单节电池风扇解决方案WD8001
  • switch-while day6
  • Spring AOP(1)
  • 小家电外贸出口新利器:WD8001低成本风扇智能控制方案全解析
  • 模块化交互数字人系统:OpenAvatarChat,单台PC即可运行完整功能
  • sourcetree中的mercurial有什么用
  • Python实例题:Flask实现简单聊天室
  • 【PCB设计】STM32开发板——原理图设计(电源部分)
  • FLgo学习
  • leetcode46.全排列:回溯算法中元素利用的核心逻辑
  • MyBatis 一级缓存与二级缓存
  • 【Python进阶】装饰器
  • 基于白鲸优化算法的路径优化研究
  • 数字化赋能智能托育实训室课程体系
  • 工业透明材料应力缺陷难检测?OAS 软件应力双折射案例来解决
  • ADK实战-基于ollama+qwen3实现外部工具串行调用
  • 帝可得 - 运营管理APP
  • MMAD论文精读
  • day20 奇异值SVD分解
  • 线程池和数据库连接池的区别
  • 3-10单元格行、列号获取(实例:表格选与维度转换)学习笔记
  • 163MusicLyrics(歌词下载工具) v7.0
  • MDP的observations部分