当前位置: 首页 > web >正文

softmax传递函数+交叉熵损失

在多分类问题中,Softmax 函数通常与交叉熵损失函数结合使用。

Softmax 函数

Softmax 函数是一种常用的激活函数,主要用于多分类问题中。它将一个实数向量转换为概率分布,使得每个元素的值在 0 到 1 之间,且所有元素的和为 1。

Softmax 函数的数学表达式:

softmax ( z i ) = e z i ∑ j = 1 K e z j \text{softmax}(z_i) = \frac{{\rm e}^{z_i}}{\sum_{j=1}^{K} {\rm e}^{z_j}} softmax(zi)=j=1Kezjezi

其中, z i z_i zi是输入向量的第 i i i个元素, K K K是向量的长度。

Softmax 函数的实现

在 Python 中,可以使用 NumPy 库来实现 Softmax 函数。

import numpy as npdef softmax(x):exp_x = np.exp(x - np.max(x))  # 防止数值溢出return exp_x / np.sum(exp_x)# 示例输入
x = np.array([2.0, 1.0, 0.1])
print(softmax(x))

Softmax 函数的应用

Softmax 函数广泛应用于机器学习中的分类问题,特别是在神经网络的输出层。它可以将网络的原始输出转换为概率分布,从而方便地进行分类决策。

在使用 Softmax 函数时,需要注意数值稳定性问题。由于指数函数的增长非常快,直接计算 e z i e^{z_i} ezi可能导致数值溢出。为了避免这个问题,通常会从输入向量中减去其最大值,再进行指数计算。

def softmax_stable(x):exp_x = np.exp(x - np.max(x))return exp_x / np.sum(exp_x)

Softmax 函数的梯度

在反向传播算法中,需要计算 Softmax 函数的梯度。

Softmax 函数的梯度公式:

∂ softmax ( z i ) ∂ z j = softmax ( z i ) ( δ i j − softmax ( z j ) ) \frac{\partial \text{softmax}(z_i)}{\partial z_j} = \text{softmax}(z_i) (\delta_{ij} - \text{softmax}(z_j)) zjsoftmax(zi)=softmax(zi)(δijsoftmax(zj))

其中, δ i j \delta_{ij} δij是 Kronecker delta 函数,当 i = j i = j i=j时为 1,否则为 0。

交叉熵损失

交叉熵损失(Cross-Entropy Loss)是深度学习中常用的损失函数,尤其在分类任务中广泛应用。它衡量模型预测的概率分布与真实标签分布之间的差异。

对于二分类问题,交叉熵损失的公式为:
L = − 1 N ∑ i = 1 N [ y i log ⁡ ( p i ) + ( 1 − y i ) log ⁡ ( 1 − p i ) ] L = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \right] L=N1i=1N[yilog(pi)+(1yi)log(1pi)]
其中, y i y_i yi是真实标签(0 或 1), p i p_i pi是模型预测的概率, N N N是样本数量。

解释

二分类问题中,交叉熵损失函数可以表示为:

L = − [ y log ⁡ ( p ) + ( 1 − y ) log ⁡ ( 1 − p ) ] L = - \left[ y \log(p) + (1 - y) \log(1 - p) \right] L=[ylog(p)+(1y)log(1p)]其中, y y y是真实标签(0 或 1), p p p是模型预测为正样本的概率。

  • y = 1 y = 1 y=1时,损失函数简化为:
    L = − log ⁡ ( p ) L = - \log(p) L=log(p)即预测概率 p p p越接近 1,损失越小。

  • y = 0 y = 0 y=0时,损失函数简化为:
    L = − log ⁡ ( 1 − p ) L = - \log(1 - p) L=log(1p)即预测概率 p p p越接近 0,损失越小。

对于多分类问题,交叉熵损失的公式为:
L = − 1 N ∑ i = 1 N ∑ c = 1 K y i , c log ⁡ ( p i , c ) L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{K} y_{i,c} \log(p_{i,c}) L=N1i=1Nc=1Kyi,clog(pi,c)
其中, y i , c y_{i,c} yi,c是样本 i i i在类别 c c c上的真实标签(0 或 1), p i , c p_{i,c} pi,c是模型预测的样本 i i i属于类别 c c c的概率, K K K是类别总数。

交叉熵损失当预测概率与真实标签一致时,损失值为 0。当预测概率与真实标签差异较大时,损失值会迅速增大,从而促使模型快速调整参数。

代码示例:交叉熵损失

def cross_entropy(y_true, y_pred):return -np.sum(y_true * np.log(y_pred))# 示例标签和预测
y_true = np.array([1, 0, 0])
y_pred = softmax(np.array([2.0, 1.0, 0.1]))print("Cross Entropy Loss:", cross_entropy(y_true, y_pred))
http://www.xdnf.cn/news/4940.html

相关文章:

  • ACTF2025 - Web writeup
  • C++编程语言:标准库:标准库概观(Bjarne Stroustrup)
  • 第六章 进阶09 我的人才观
  • 【设计模式】GoF设计模式之策略模式(Strategy Pattern)
  • rust 中的 EBNF 介绍
  • Uniapp编写微信小程序,使用canvas进行绘图
  • uni-app,小程序中的addPhoneContact,保存联系人到手机通讯录
  • 不止是UI库:React如何重塑前端开发范式?
  • Java中的内部类详解
  • iOS创建Certificate证书、制作p12证书流程
  • eNSP中路由器RIP协议配置完整实验实验和命令解释
  • Starrocks 的 ShortCircuit短路径
  • Rspress-快如闪电的静态站点生成器
  • Linux 学习笔记1
  • cilium路由模式和aws-eni模式下的IPAM
  • MySQL有哪些高可用方案?
  • CommunityToolkit.Mvvm详解
  • 前端面试每日三题 - Day 29
  • JavaScript性能优化实战,从理论到落地的全面指南
  • 阿里云 SLS 多云日志接入最佳实践:链路、成本与高可用性优化
  • webpack代理天地图瓦片
  • 【C++设计模式之Template Method Pattern】
  • mysql 已经初始化好,但是用 dbeaver 连接报错:Public Key Retrieval is not allowed
  • 2025数字孪生技术全景洞察:从工业革命到智慧城市的跨越式发展
  • Vue项目---懒加载的应用
  • Redhat 系统详解
  • 在linux系统中,没有网络如何生成流量以使得wireshark能捕获到流量
  • 数组和切片的区别
  • C#字段、属性、索引器、常量
  • 快速开发-基于gin的中间件web项目开发