当前位置: 首页 > news >正文

基础神经网络模型搭建

nn 包提供通用深度学习网络的模块集合,接收输入张量,计算输出张量,并保存权重。通常使用两种途径搭建 PyTorch 中的模型:nn.Sequential和 nn.Module。

nn.Sequential通过线性层有序组合搭建模型;nn.Module通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

目录

搭建线性层

通过nn.Sequential搭建

通过nn.Module搭建

获取模型摘要


搭建线性层

使用 nn 包搭建线性层。线性层接收 64*1000 维的输入,保存 1000*100 维的权重,并计算 64*100 维的输出。

import torch
from torch import nn
input_tensor = torch.randn(64, 1000)
linear_layer = nn.Linear(1000, 100)
output = linear_layer(input_tensor)
print(input_tensor.size())
print(output.size())

通过nn.Sequential搭建

考虑一个两层的神经网络,四个节点作为输入,五个节点在隐藏层,一个节点作为输出

from torch import nn
model = nn.Sequential(nn.Linear(4, 5),nn.ReLU(),nn.Linear(5, 1),
)
print(model)

通过nn.Module搭建

在 PyTorch 中搭建模型的另一种方法是对 nn.Module 类进行子类化,通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

考虑两个卷积层和两个完全连接层搭建的模型:

import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net, self).__init__()def forward(self, x):pass

定义__init__ 函数和forward 函数

def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)
def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

重写两个类函数并打印模型

重写:子类中实现一个与父类的成员函数原型完全相同的函数

Net.__init__ = __init__
Net.forward = forward
model = Net()
print(model)

 查看模型位置

print(next(model.parameters()).device)

 

将模型移动至CUDA设备 

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

获取模型摘要

借助torchsummary包查获取模型摘要

pip install torchsummary
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

http://www.xdnf.cn/news/1159885.html

相关文章:

  • 【数据结构】栈和队列(接口超完整)
  • jQuery 插件
  • 本地部署 Claude 大语言模型的完整实践指南
  • 创建一个触发csrf的恶意html
  • 创新几何解谜游戏,挑战空间思维极限
  • ollama基本配置
  • 玄机——第六章 流量特征分析-蚂蚁爱上树
  • 2025最新 PostgreSQL17 安装及配置(Windows原生版)
  • 【Go语言-Day 22】解耦与多态的基石:深入理解 Go 接口 (Interface) 的核心概念
  • [硬件电路-59]:电源:电子存储的仓库,电能的发生地,电场的动力场所
  • 手写tomcat
  • API获取及调用(以豆包为例实现图像分析)
  • 用 Jetpack Compose 写 Android 的 “Hello World”
  • SSE和WebSocket区别到底是什么
  • linux shell从入门到精通(一)——为什么要学习Linux Shell
  • MongoDB多节点集群原理 -- 复制集
  • 《杜甫传》读书笔记与经典摘要(一)
  • 人工智能之数学基础:随机实验、样本空间、随机事件
  • 【算法训练营Day15】二叉树part5
  • LVS-----TUN模式配置
  • 【LeetCode刷题指南】--反转链表,链表的中间结点,合并两个有序链表
  • 【原创】微信小程序添加TDesign组件
  • tabBar设置底部菜单选项、iconfont图标(图片)库、模拟京东app的底部导航栏
  • 零基础学习性能测试第三章:执行性能测试
  • Windows CMD(命令提示符)中最常用的命令汇总和实战示例
  • 30天打牢数模基础-SVM讲解
  • Python 单例模式几种实现方式
  • Dify 1.6 安装与踩坑记录(Docker 方式)
  • ZooKeeper学习专栏(二):深入 Watch 机制与会话管理
  • 【单片机外部中断实验修改动态数码管0-99】2022-5-22