当前位置: 首页 > news >正文

Relook:softmax函数

1. 计算 Z = W T X + b Z=W^TX+\boldsymbol{b} Z=WTX+b

定义变量

定义 X ∈ R D in × N , W ∈ R D in × D out , b ∈ R D out X \in \mathbb{R}^{D_{\text{in}} \times N}, W \in \mathbb{R}^{D_{\text{in}} \times D_{\text{out}}}, b \in \mathbb{R}^{D_{\text{out}}} XRDin×N,WRDin×Dout,bRDout,则

X = [ x 11 x 12 ⋯ x 1 N x 21 x 22 ⋯ x 2 N ⋮ ⋮ ⋮ ⋮ x D i n 1 x D i n 2 ⋯ x D i n N ] = [ x 1 x 2 ⋯ x N ] X=\begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1N} \\ x_{21} & x_{22} & \cdots & x_{2N} \\ \vdots & \vdots & \vdots & \vdots \\ x_{D_{in}1} & x_{D_{in}2} & \cdots & x_{D_{in}N} \\ \end{bmatrix}= \begin{bmatrix} \boldsymbol{x}_1 \boldsymbol \ \boldsymbol{x}_2 \cdots \boldsymbol{x}_N\\ \end{bmatrix} X= x11x21xDin1x12x22xDin2x1Nx2NxDinN =[x1 x2xN]

式中, x i = [ x 1 i , x 2 i , ⋯ , x D i n i ] T \boldsymbol{x_i}=[x_{1i}, x_{2i}, \cdots, x_{D_{in}i}]^T xi=[x1i,x2i,,xDini]T,每一列为一个样本,且样本是用列向量表示。

W = [ w 11 w 12 ⋯ w 1 D o u t w 21 w 22 ⋯ w 2 D o u t ⋮ ⋮ ⋱ ⋮ w D i n 1 w D i n 2 ⋯ w D i n D o u t ] = [ w 1 w 2 ⋯ w D o u t ] W = \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1D_{out}} \\ w_{21} & w_{22} & \cdots & w_{2D_{out}} \\ \vdots & \vdots & \ddots & \vdots \\ w_{D_{in}1} & w_{D_{in}2} & \cdots & w_{D_{in}D_{out}} \\ \end{bmatrix}= \begin{bmatrix} \boldsymbol{w}_1 \boldsymbol \ \boldsymbol{w}_2 \cdots \boldsymbol{w}_{D_{out}} \end{bmatrix} W= w11w21wDin1w12w22wDin2w1Doutw2DoutwDinDout =[w1 w2wDout]

式中, w i = [ w 1 i , w 2 i , ⋯ , w D i n i ] T \boldsymbol{w}_i=[w_{1i}, w_{2i}, \cdots, w_{D_{in}i}]^T wi=[w1i,w2i,,wDini]T,每一列为一个权重向量。

计算权重转置 W T W^T WT

W T = [ w 11 w 21 ⋯ w D i n 1 w 12 w 22 ⋯ w D i n 2 ⋮ ⋮ ⋱ ⋮ w 1 D o u t w 2 D o u t ⋯ w D i n D o u t ] = [ w 1 T w 2 T ⋮ w D o u t T ] W^T = \begin{bmatrix} w_{11} & w_{21} & \cdots & w_{D_{in}1} \\ w_{12} & w_{22} & \cdots & w_{D_{in}2} \\ \vdots & \vdots & \ddots & \vdots \\ w_{1D_{out}} & w_{2D_{out}} & \cdots & w_{D_{in}D_{out}} \\ \end{bmatrix}= \begin{bmatrix} \boldsymbol{w}_1^T \\ \boldsymbol{w}_2^T \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \\ \end{bmatrix} WT= w11w12w1Doutw21w22w2DoutwDin1wDin2wDinDout = w1Tw2TwDoutT

注意,由于 w i \boldsymbol{w}_i wi是一个列向量,则对应的转置即为行向量。

计算线性变换 W T X W^TX WTX

W T X = [ w 1 T w 2 T ⋮ w D o u t T ] [ x 1 x 2 ⋯ x N ] = [ w 1 T x 1 w 1 T x 2 ⋯ w 1 T x N w 2 T x 1 w 2 T x 2 ⋯ w 2 T x N ⋮ ⋮ ⋱ ⋮ w D o u t T x 1 w D o u t T x 2 ⋯ w D o u t T x N ] W^T X = \begin{bmatrix} \boldsymbol{w}_1^T \\ \boldsymbol{w}_2^T \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \\ \end{bmatrix} \begin{bmatrix} \boldsymbol{x}_1 \boldsymbol \ \boldsymbol{x}_2 \cdots \boldsymbol{x}_N\\ \end{bmatrix}= \begin{bmatrix} \boldsymbol{w}_1^T \boldsymbol{x}_1 & \boldsymbol{w}_1^T \boldsymbol{x}_2 & \cdots & \boldsymbol{w}_1^T \boldsymbol{x}_N \\ \boldsymbol{w}_2^T \boldsymbol{x}_1 & \boldsymbol{w}_2^T \boldsymbol{x}_2 & \cdots & \boldsymbol{w}_2^T \boldsymbol{x}_N \\ \vdots & \vdots & \ddots & \vdots \\ \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_1 & \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_2 & \cdots & \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_N \\ \end{bmatrix} WTX= w1Tw2TwDoutT [x1 x2xN]= w1Tx1w2Tx1wDoutTx1w1Tx2w2Tx2wDoutTx2w1TxNw2TxNwDoutTxN

为了更加直观的观察,采用 A [ x 1 x 2 ⋯ x n ] = [ A x 1 A x 2 ⋯ A x n ] A[x_1\ x_2\ \cdots\ x_n]=[Ax_1\ Ax_2\ \cdots\ Ax_n] A[x1 x2  xn]=[Ax1 Ax2  Axn]这种矩阵乘法的表达形式,则上式可以写成:
W T X = [ [ w 1 T w 2 T ⋮ w D o u t T ] x 1 [ w 1 T w 2 T ⋮ w D o u t T ] x 2 ⋯ [ w 1 T w 2 T ⋮ w D o u t T ] x N ] W^T X = \begin{bmatrix} \begin{bmatrix} \boldsymbol{w}_1^T \\ \boldsymbol{w}_2^T \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \\ \end{bmatrix} \boldsymbol{x}_1 & \begin{bmatrix} \boldsymbol{w}_1^T \\ \boldsymbol{w}_2^T \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \\ \end{bmatrix} \boldsymbol{x}_2 & \cdots & \begin{bmatrix} \boldsymbol{w}_1^T \\ \boldsymbol{w}_2^T \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \\ \end{bmatrix} \boldsymbol{x}_N \end{bmatrix} WTX= w1Tw2TwDoutT x1 w1Tw2TwDoutT x2 w1Tw2TwDoutT xN

添加偏置 b \boldsymbol{b} b(广播机制)

W T X + b = [ w 1 T x 1 w 1 T x 2 ⋯ w 1 T x N w 2 T x 1 w 2 T x 2 ⋯ w 2 T x N ⋮ ⋮ ⋱ ⋮ w D o u t T x 1 w D o u t T x 2 ⋯ w D o u t T x N ] + [ b 1 b 2 ⋮ b D o u t ] = [ w 1 T x 1 + b 1 w 1 T x 2 + b 1 ⋯ w 1 T x N + b 1 w 2 T x 1 + b 2 w 2 T x 2 + b 2 ⋯ w 2 T x N + b 2 ⋮ ⋮ ⋱ ⋮ w D o u t T x 1 + b D o u t w D o u t T x 2 + b D o u t ⋯ w D o u t T x N + b D o u t ] D o u t × N \begin{align} W^T X + b &= \begin{bmatrix} \boldsymbol{w}_1^T \boldsymbol{x}_1 & \boldsymbol{w}_1^T \boldsymbol{x}_2 & \cdots & \boldsymbol{w}_1^T \boldsymbol{x}_N \\ \boldsymbol{w}_2^T \boldsymbol{x}_1 & \boldsymbol{w}_2^T \boldsymbol{x}_2 & \cdots & \boldsymbol{w}_2^T \boldsymbol{x}_N \\ \vdots & \vdots & \ddots & \vdots \\ \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_1 & \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_2 & \cdots & \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_N \\ \end{bmatrix}+ \begin{bmatrix} b_1 \\ b_2 \\ \vdots \\ b_{D_{out}} \\ \end{bmatrix} \\&= \begin{bmatrix} \boldsymbol{w}_1^T \boldsymbol{x}_1 + b_1 & \boldsymbol{w}_1^T \boldsymbol{x}_2 + b_1 & \cdots & \boldsymbol{w}_1^T \boldsymbol{x}_N + b_1 \\ \boldsymbol{w}_2^T \boldsymbol{x}_1 + b_2 & \boldsymbol{w}_2^T \boldsymbol{x}_2 + b_2 & \cdots & \boldsymbol{w}_2^T \boldsymbol{x}_N + b_2 \\ \vdots & \vdots & \ddots & \vdots \\ \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_1 + b_{D_{out}} & \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_2 + b_{D_{out}} & \cdots & \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_N + b_{D_{out}} \\ \end{bmatrix}_{D_{out}\times N} \end{align} WTX+b= w1Tx1w2Tx1wDoutTx1w1Tx2w2Tx2wDoutTx2w1TxNw2TxNwDoutTxN + b1b2bDout = w1Tx1+b1w2Tx1+b2wDoutTx1+bDoutw1Tx2+b1w2Tx2+b2wDoutTx2+bDoutw1TxN+b1w2TxN+b2wDoutTxN+bDout Dout×N

输出矩阵 Z Z Z

Z = W T X + b = [ z 11 z 12 ⋯ z 1 N z 21 z 22 ⋯ z 2 N ⋮ ⋮ ⋱ ⋮ z D o u t 1 z D o u t 2 ⋯ z D o u t N ] D o u t × N Z = W^T X + b = \begin{bmatrix} z_{11} & z_{12} & \cdots & z_{1N} \\ z_{21} & z_{22} & \cdots & z_{2N} \\ \vdots & \vdots & \ddots & \vdots \\ z_{D_{out}1} & z_{D_{out}2} & \cdots & z_{D_{out}N} \\ \end{bmatrix}_{D_{out}\times N} Z=WTX+b= z11z21zDout1z12z22zDout2z1Nz2NzDoutN Dout×N

因此,对于任意一列样本 x i \boldsymbol{x}_i xi,对应的输出为

z i = [ w 1 T x i + b 1 w 2 T x i + b 2 ⋮ w D o u t T x i + b D o u t ] = [ w 1 T w 2 T ⋮ w D o u t T ] x i + [ b 1 b 2 ⋮ b D o u t ] \boldsymbol{z}_i = \begin{bmatrix} \boldsymbol{w}_1^T \boldsymbol{x}_i + b_1 \\ \boldsymbol{w}_2^T \boldsymbol{x}_i + b_2 \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \boldsymbol{x}_i + b_{D_{out}} \\ \end{bmatrix}= \begin{bmatrix} \boldsymbol{w}_1^T \\ \boldsymbol{w}_2^T \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \\ \end{bmatrix} \boldsymbol{x}_i + \begin{bmatrix} b_1 \\ b_2 \\ \vdots \\ b_{D_{out}} \\ \end{bmatrix} zi= w1Txi+b1w2Txi+b2wDoutTxi+bDout = w1Tw2TwDoutT xi+ b1b2bDout

2. 点积(内积/标量积)定义

点积(dot product)、标量积(scalar product)、内积(inner product)都是同一种意义,叫法不同。其中,点积是基于代数运行,两个长度相等的数列对应位置相乘,然后求和。标量积,是因为运算的结果为一个数,即标量值,所以称标量积。从几何角度出发,在欧几里空间中,两个向量的欧几里得模长与它们夹角余弦值的乘积,称为内积。

坐标定义

对于两个长度相等的列向量 a = [ a 1 , a 2 , ⋯ , a n ] T , b = [ b 1 , b 2 , ⋯ , b n ] T \boldsymbol{a} = [a_1, a_2, \cdots, a_n]^T,\boldsymbol{b} = [b_1, b_2, \cdots, b_n]^T a=[a1,a2,,an]T,b=[b1,b2,,bn]T,它们的点积定义如下:
a ⋅ b = ∑ i = 1 n a i b i = a 1 b 1 + a 2 b 2 + ⋯ + a n b n \boldsymbol{a} \cdot \boldsymbol{b} = \sum_{i = 1}^{n} a_i b_i = a_1b_1 + a_2b_2 + \cdots + a_nb_n ab=i=1naibi=a1b1+a2b2++anbn
同样地,可以采用矩阵乘法的形式进行运算:

a ⋅ b = a T b = [ a 1 a 2 ⋯ a n ] × [ b 1 b 2 ⋮ b n ] = a 1 b 1 + a 2 b 2 + ⋯ + a n b n \begin{align} \boldsymbol{a} \cdot \boldsymbol{b} &= \boldsymbol{a}^T\boldsymbol{b}= [a_1 \ a_2\ \cdots a_n] \times \begin{bmatrix} b_1 \\ b_2 \\ \vdots \\ b_n \\ \end{bmatrix}=a_1b_1 + a_2b_2 + \cdots + a_nb_n \end{align} ab=aTb=[a1 a2 an]× b1b2bn =a1b1+a2b2++anbn

几何定义

在欧几里得空间中,欧几里得向量是一种兼具大小(模长)和方向的几何对象,两个欧几里得向量的点积定义为:
a ⋅ b = ∣ ∣ a ∣ ∣ ∣ ∣ b ∣ ∣ cos ⁡ ( θ ) \boldsymbol{a} \cdot \boldsymbol{b} = ||\boldsymbol{a}||\ ||\boldsymbol{b}||\cos(\theta) ab=∣∣a∣∣ ∣∣b∣∣cos(θ)

3. 重新思考Softmax

这里忽略偏置 b \boldsymbol{b} b,对于任意一个样本 x \boldsymbol{x} x,全连接的分类输出为
z = [ w 1 T x w 2 T x ⋮ w D o u t T x ] \boldsymbol{z} = \begin{bmatrix} \boldsymbol{w}_1^T \boldsymbol{x} \\ \boldsymbol{w}_2^T \boldsymbol{x} \\ \vdots \\ \boldsymbol{w}_{D_{out}}^T \boldsymbol{x} \\ \end{bmatrix} z= w1Txw2TxwDoutTx
进一步,假设是3分类问题,Fc的输出维度 D o u t = 3 D_{out}=3 Dout=3,则
z = [ w 1 T x w 2 T x w 3 T x ] \boldsymbol{z} = \begin{bmatrix} \boldsymbol{w}_1^T \boldsymbol{x} \\ \boldsymbol{w}_2^T \boldsymbol{x} \\ \boldsymbol{w}_3^T \boldsymbol{x} \\ \end{bmatrix} z= w1Txw2Txw3Tx
根据前面条件可知, x \boldsymbol{x} x是列向量, w i \boldsymbol{w}_i wi也是列向量,则
w i T x = ∣ ∣ w i T ∣ ∣ ∣ ∣ x ∣ ∣ cos ⁡ ( θ i ) \boldsymbol{w}_i^T \boldsymbol{x} =||\boldsymbol{w}_i^T||\ ||\boldsymbol{x}||\cos(\theta_i) wiTx=∣∣wiT∣∣ ∣∣x∣∣cos(θi)
因此,可以表达为下图内容所示。基于三分类的问题,若某个类别的softmax输出概率最大,则认为输入样本 x \boldsymbol{x} x为该类别。要使某个某个的概率最大,则必须使得内积的结果最大。可以看出,内积的大写与两个变量有关,即 w i , θ i \boldsymbol{w}_i,\theta_i wi,θi

在这里插入图片描述

当输入样本 x \boldsymbol{x} x与某个类别权重 w i \boldsymbol{w}_i wi的夹角越小,且该类别权重的模长越长,则样本 x \boldsymbol{x} x被预测为该类别的概率就越大。

4. 总结

上述相关对softmax进行魔改的工作相当多,可以参考人脸识别领域的相关改进softmax loss工作,如:

  • [1612.02295] Large-Margin Softmax Loss for Convolutional Neural Networks
  • [1704.08063] SphereFace: Deep Hypersphere Embedding for Face Recognition
  • [1801.09414] CosFace: Large Margin Cosine Loss for Deep Face Recognition
  • [1801.05599] Additive Margin Softmax for Face Verification
  • ArcFace: Additive Angular Margin Loss for Deep Face Recognition

个人感觉,改修的各种版本softmax loss具有明确的可解释性,在一定程度上可能会提升收敛速度,但是准确率的提升可能不会太明显。(目前还没在代码中体验)

http://www.xdnf.cn/news/1041625.html

相关文章:

  • 状态机(State Machine)详解
  • 车载功能框架 --- 整车安全策略
  • 第六届经济管理与大数据应用国际学术会议 (ICEMBDA 2025)
  • 数据库学习(六)——MySQL事务
  • QT打包应用
  • 天邑TEWA-808AE高安版_S905L3B融合机破解TTL刷机包
  • python做题日记(17)
  • 15.vue.js的watch()和watchEffect()(2)
  • JAVA理论第十八章-JWT杂七杂八
  • Visualized_BGE 安装—多模态嵌入技术
  • Java 复习题选择题(1)(Java概述)
  • LLMs 系列实操科普(5)
  • 【卫星通信】Skylo与ViaSat标准建议详解:基于NB-IoT NTN通过GEO卫星实现IMS语音通话的解决方案
  • springboot在线BLOG网
  • SCADA|信创KingSCADA4.0历史报警查询的差异
  • 永磁同步电机控制算法--双矢量模型预测转矩控制MPTC(占空比)
  • [直播推流] 本地创建 nginx服务器
  • DataHub 架构设计与模块规划
  • 深度解析SpringBoot自动化部署实战:从原理到最佳实践
  • Android 安卓应用分身多开 适用于没有自带分身多开的Android设备,隐藏应用、应用锁、私密相册等管理,解锁永久Vip会员功能
  • 【精华】这样设计高性能短链生成系统
  • 记利用AI模型制作DataDump Scripts生成工具
  • 理解 C++ 的 this 指针
  • Seata与消息队列(如RocketMQ)如何实现最终一致性?
  • 【构建】CMake 构建系统重点内容
  • springboot音乐网站与分享平台
  • MySQL-DML语句深度解析与实战指南
  • 60天python训练计划----day52
  • Golang 在 Linux 平台上的并发控制
  • LeetCode - LCR 173. 点名