当前位置: 首页 > java >正文

PyTorch——线性层及其他层介绍(6)


线性层

前面1,1,1是你想要的,后面我们不知道这个值是多少,取-1让Python自己计算


import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader# 加载CIFAR-10测试数据集并转换为Tensor格式
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)# 创建数据加载器,每批次包含64个样本
dataloader = DataLoader(dataset, batch_size=64)# 定义神经网络模型TY
class TY(nn.Module):def __init__(self):super(TY, self).__init__()# 定义全连接层:输入维度196608,输出维度10(对应10个类别)self.Linear1 = Linear(196608, 10)def forward(self, input):# 前向传播:将输入数据通过全连接层output = self.Linear1(input)return output# 实例化模型
ty = TY()# 遍历数据加载器中的每个批次
for data in dataloader:# 获取图像数据和对应的标签imgs, target = data# 打印原始图像张量形状:[批次大小, 通道数, 高度, 宽度]print(imgs.shape)# 将图像张量展平为一维向量# 注意:此处reshape参数(1,1,1,-1)会导致维度错误,正确应为(-1, 196608)output = torch.reshape(imgs, (1, 1, 1, -1))# 打印展平后的张量形状print(output.shape)# 将展平后的数据输入模型output = ty(output)# 打印模型输出形状:[批次大小, 类别数]print(output.shape)


另一种表达  flatten展平

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset,batch_size=64)class TY(nn.Module):def __init__(self):super(TY,self).__init__()self.Linear1 = Linear(196608,10)def forward(self,input):output = self.Linear1(input)return outputty = TY()for data in dataloader:imgs,target = dataprint(imgs.shape)output=torch.flatten(imgs)print(output.shape)output = ty(output)print(output.shape)

http://www.xdnf.cn/news/10725.html

相关文章:

  • 【HarmonyOS 5】鸿蒙APP使用【团结引擎Unity】开发的案例教程
  • LEAP模型能源需求/供应预测、能源平衡表核算、空气污染物排放预测、碳排放建模预测、成本效益分析、电力系统优化
  • 【macbook】触控板手势
  • 数据解析:一文掌握Python库 lxml 的详细使用(处理XML和HTML的高性能库)
  • 基于 COM 的 XML 解析技术(MSXML) 的总结
  • CSS设置移动端页面底部安全距离
  • 【Hot 100】279. 完全平方数
  • PopupImageMenuItem 无响应
  • AXURE-动态面板
  • 最优包含--字符串dp
  • 解锁技术文档撰写秘籍:从混沌到清晰的蜕变之旅
  • 帝可得 - 策略管理
  • 利用Python 进行自动化操作: Pyautogui 库
  • SQL注入漏洞-上篇
  • 正点原子lwIP协议的学习笔记
  • xmake的简易学习
  • CppCon 2014 学习:Cross platform GUID association with types
  • 蛋白质设计软件LigandMPNN介绍
  • 宇树科技更名“股份有限公司”深度解析:机器人企业IPO前奏与资本化路径
  • R1-Searcher++新突破!强化学习如何赋能大模型动态知识获取?
  • 职坐标IT培训:嵌入式开发C语言/硬件/RTOS路径
  • 时代星光推出战狼W60智能运载无人机,主要性能超市场同类产品一倍!
  • NLP实战(5):基于LSTM的电影评论情感分析模型研究
  • BugKu Web渗透之源代码
  • C++ stl容器之string(字符串类)
  • .NET 原生驾驭 AI 新基建实战系列(一):向量数据库的应用与畅想
  • 利用 Scrapy 构建高效网页爬虫:框架解析与实战流程
  • 2022年 国内税务年鉴PDF电子版Excel
  • centos安装locate(快速查找linux文件)
  • 【QT】QString 与QString区别