当前位置: 首页 > java >正文

深度学习3.7 softmax回归的简洁实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.7.1 初始化模型参数

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

3.7.2 重新审视Softmax的实现

loss = nn.CrossEntropyLoss(reduction='none')

3.7.3 优化算法

# 在这里,我们(使用学习率为0.1的小批量随机梯度下降作为优化算法)
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

3.7.4 训练

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

在这里插入图片描述

3.7.5 预测

batch_size = 256 #迭代器批量
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)def predict_ch3(net, test_iter, n=6):  """Predict labels (defined in Chapter 3)."""for X, y in test_iter:  # 获取第一批测试数据breaktrues = d2l.get_fashion_mnist_labels(y)  # 真实标签转文本preds = d2l.get_fashion_mnist_labels(d2l.argmax(net(X), axis=1))  # 预测标签转文本titles = [true +'\n' + pred for true, pred in zip(trues, preds)]  # 组合标签d2l.show_images(d2l.reshape(X[0:n], (n, 28, 28)), 1, n, titles=titles[0:n])  # 可视化predict_ch3(net, test_iter)

在这里插入图片描述

http://www.xdnf.cn/news/1627.html

相关文章:

  • Java面试:从Spring Boot到微服务的全面考核
  • sysstat介绍以及交叉编译
  • 【Redis】 Redis中常见的数据类型(二)
  • 如何解决PyQt从主窗口打开新窗口时出现闪退的问题
  • 逐步了解蓝牙 LE 配对(物联网网络安全)
  • 2024ICPC网络赛第一场题解
  • vue2如何二次封装表单控件如input, select等
  • Excel处理控件Aspose.Cells教程:使用 Python 在 Excel 中进行数据验
  • Diffusion inversion后的latent code与标准的高斯随机噪音不一样
  • 手机访问电脑端Nginx服务器配置方式
  • 新规!专利优先审查,每个申请主体每月推荐不超过2件。
  • 配置 C/C++ 语言智能感知(IntelliSense)的 c_cpp_properties.json 文件内容
  • 【k8s】KubeProxy 的三种工作模式——Userspace、iptables 、 IPVS
  • Maxscale实现Mysql的读写分离
  • 第七届能源系统与电气电力国际学术会议(ICESEP 2025)
  • 力扣热题100题解(c++)—矩阵
  • 碰一碰发视频源码文案功能,支持OEM
  • 扩散模型(Diffusion Model)详解:原理、发展与应用
  • VS Code扩张安装目录
  • CSS element-ui Icon Unicode 编码引用
  • websocket
  • 什么是 YAML:技术特性、应用场景与实践指南
  • 深入探索Spark-Streaming:从Kafka数据源创建DStream
  • CPT204 Advanced Obejct-Oriented Programming 高级面向对象编程 Pt.8 排序算法
  • 算法设计与分析(基础)
  • JetBrains GoLang IDE无限重置试用期,适用最新2025版
  • CentOS系统中MySQL安装步骤分享
  • 计算机图形学实践:结合Qt和OpenGL实现绘制彩色三角形
  • 硬件知识点-----SPI串联电阻、振铃、过冲
  • python的mtcnn检测图片中的人脸并标框