当前位置: 首页 > news >正文

Pytorch实现感知器并实现分类动画

这个实现包含以下关键部分:

  1. 数据生成:使用用户提供的函数生成两类可线性分离的数据点。

  2. 感知机模型

    • 一个线性层接收二维输入并输出一个值
    • 不使用激活函数(原始感知机形式)
    • 使用均方误差损失函数(MSE)和随机梯度下降优化器
  3. 动态可视化

    • 使用 matplotlib 的 FuncAnimation 创建动画
    • 每帧更新显示当前决策边界和损失值
    • 数据点根据真实标签着色(蓝色为 - 1,红色为 1)
    • 绿色线表示当前感知机的决策边界

运行代码后,你将看到一个动画展示感知机如何逐步学习区分两类数据的决策边界。随着训练的进行,决策边界会不断调整,直到能够正确分离两个类别。

 

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation# 数据生成函数(保持与用户提供的一致)
def generate_data():np.random.seed(0)class_1 = np.random.randn(100, 2) + np.array([2, 2])class_2 = np.random.randn(100, 2) + np.array([-2, -2])labels_1 = np.ones((100, 1))labels_2 = -np.ones((100, 1))data = np.vstack((class_1, class_2))labels = np.vstack((labels_1, labels_2))return torch.Tensor(data), torch.Tensor(labels)# 感知机模型
class Perceptron(nn.Module):def __init__(self):super(Perceptron, self).__init__()self.linear = nn.Linear(2, 1)  # 二维输入,一维输出def forward(self, x):return self.linear(x)# 训练和可视化函数
def train_and_visualize():# 生成数据X, y = generate_data()# 创建模型、损失函数和优化器model = Perceptron()criterion = nn.MSELoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 设置图形fig, ax = plt.subplots(figsize=(10, 8))scatter = ax.scatter(X[:, 0], X[:, 1], c=y.numpy().flatten(), cmap='coolwarm', alpha=0.7)line, = ax.plot([], [], 'g-', lw=2)ax.set_xlim(-6, 6)ax.set_ylim(-6, 6)ax.set_title('Perceptron Classification')# 初始化线def init():line.set_data([], [])return line,# 更新函数def update(frame):# 训练一步optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()# 获取当前权重和偏置w1, w2 = model.linear.weight.data[0]b = model.linear.bias.data[0]# 计算决策边界x_vals = np.linspace(-6, 6, 100)y_vals = -(w1 * x_vals + b) / w2# 更新线line.set_data(x_vals, y_vals)ax.set_title(f'Perceptron Classification (Epoch {frame + 1}, Loss: {loss.item():.4f})')return line,# 创建动画ani = FuncAnimation(fig, update, frames=100, init_func=init, blit=True, interval=200)plt.show()return ani# 运行训练和可视化
if __name__ == "__main__":animation = train_and_visualize()

http://www.xdnf.cn/news/1115641.html

相关文章:

  • JAVA并发——什么是Java的原子性、可见性和有序性
  • git实操
  • composer如何安装以及举例在PHP项目中使用Composer安装TCPDF库-优雅草卓伊凡
  • 【基础算法】倍增
  • 【开源项目】拆解机器学习全流程:一份GitHub手册的工程实践指南
  • 从儿童涂鸦到想象力视频:AI如何重塑“亲子创作”市场?
  • ABP VNext + 多级缓存架构:本地 + Redis + CDN
  • Linux的 iproute2 配置:以太网(Ethernet)、绑定(Bond)、虚拟局域网(VLAN)、网桥(Bridge)笔记250713
  • Prometheus 第一篇:快速上手
  • Vue配置特性(ref、props、混入、插件与作用域样式)
  • 第三章-提示词-解锁Prompt提示词工程核销逻辑,开启高效AI交互(10/36)
  • Linux|服务器|二进制部署nacos(不是集群,单实例)(2025了,不允许还有人不会部署nacos)
  • 学习C++、QT---23(QT中QFileDialog库实现文件选择框打开、保存讲解)
  • DVWA靶场通关笔记-XSS DOM(Medium级别)
  • 教程:如何查看浏览器扩展程序的源码
  • 飞算 JavaAI 智能编程助手:颠覆编程旧模式,重构开发生态
  • 闲庭信步使用图像验证平台加速FPGA的开发:第十三课——图像浮雕效果的FPGA实现
  • JAVA生成PDF(itextpdf)
  • 互联网大厂Java面试:从Spring Boot到微服务的场景应用
  • HTML 初体验
  • HarmonyOS组件/模板集成创新活动-元服务小云体重管理引入案例(步骤条UI组件)
  • HarmonyOS组件/模板集成创新活动-开发者工具箱
  • 【设计模式】备忘录模式(标记(Token)模式)
  • 为什么玩游戏用UDP,看网页用TCP?
  • 融合开源AI大模型与MarTech:AI智能名片与S2B2C商城小程序源码赋能数字化营销新生态
  • 【QT】使用QSS进行界面美化
  • 【Linux | 网络】应用层
  • Rust赋能文心大模型4.5智能开发
  • Leetcode 3615. Longest Palindromic Path in Graph
  • 操作系统-第四章存储器管理和第五章设备管理-知识点整理(知识点学习 / 期末复习 / 面试 / 笔试)