当前位置: 首页 > news >正文

神经网络——线性层

在机器学习中,线性层(Linear Layer) 是一种基础的神经网络组件,也称为全连接层(Fully Connected Layer) 或密集层(Dense Layer)

其严格的数学定义为:对输入数据执行线性变换,生成输出向量。

具体形式为:
                Y=XW+b  
其中:

  • X 是输入张量,通常形状为 [批次大小, 输入维度]

  • W 是可学习的权重矩阵,形状为 [输入维度, 输出维度]

  • b 是可学习的偏置向量,形状为 [输出维度]

  • Y 是输出张量,形状为 [批次大小, 输出维度]

核心特性

  1. 参数共享:同一层内的所有输入神经元都通过权重矩阵 W 与输出神经元相连,权重在整个输入空间中共享。

  2. 线性变换:仅能表示线性函数,因此通常与非线性激活函数(如 ReLU)组合使用,以增强模型表达能力。

  3. 特征投影:本质上是将输入特征投影到新的特征空间,输出维度决定了新空间的维度。

线性网络: 

 


 

参数

  • in_features (int) – 每个输入样本的大小

  • out_features (int) – 每个输出样本的大小

  • bias (bool) – 如果设置为 False,该层将不学习加性偏置。默认值: True

 

 代码举例

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../torchvision_dataset", train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset, batch_size=64)class MyModule(nn.Module):def __init__(self):super().__init__()"""下文:展开前:torch.Size([64, 3, 32, 32])展开后:torch.Size([1, 1, 1, 196608])"""self.linear = Linear(196608, 10)def forward(self, input):output = self.linear(input)return outputmodule = MyModule()for data in dataloader:imgs, targets = dataprint("原本图像尺寸", imgs.shape)# 把二维图片展开成一维的# imgs=torch.reshape(imgs,(1,1,1,-1))imgs = torch.flatten(imgs)print("展平后图像尺寸", imgs.shape)output = module(imgs)print("经过线性层处理后图像尺寸", output.shape)

 

http://www.xdnf.cn/news/1164853.html

相关文章:

  • 【c++】leetcode5 最长回文子串
  • 蚂蚁数科AI数据产业基地正式投产,携手苏州推进AI产业落地
  • 奥比中光深度相机开发
  • 感知机-梯度下降法
  • 141 个 LangChain4j Maven 组件分类解析、多场景实战攻略
  • 一个月掌握数据结构与算法:高效学习计划
  • hot100回归复习(算法总结1-38)
  • 零拷贝技术(Zero-Copy)
  • 网络协议(四)网络层 路由协议
  • C++基于libmodbus库实现modbus TCP/RTU通信
  • 大模型——上下文工程 (Context Engineering) – 现代 AI 系统的架构基础
  • C# 实现:动态规划解决 0/1 背包问题
  • iOS开发 Swift 速记2:三种集合类型 Array Set Dictionary
  • OCR 身份识别:让身份信息录入场景更高效安全
  • 如何使用终端查看任意Ubuntu的版本信息
  • 用 Three.js 实现 PlayCanvas 风格 PBR 材质教程(第二篇):核心参数与光照模型
  • DBSCAN聚类算法
  • OpenAI Codex CLI与 Google Gemini CLI 比较
  • 关于java8里边Collectors.toMap()的空限制
  • 泛型:C#中的类型抽象艺术
  • Android NDK ffmpeg 音视频开发实战
  • 数据结构 之 【排序】(直接插入排序、希尔排序)
  • 【C++】list的模拟实现
  • 音视频学习(四十二):H264帧间压缩技术
  • 周志华《机器学习导论》第13章 半监督学习
  • [深度学习] 大模型学习3上-模型训练与微调
  • 机器学习初学者理论初解
  • MySQL:表的增删查改
  • 基于VSCode的nRF52840开发环境搭建
  • C++高性能日志库spdlog介绍