由sigmod权重曲线存在锯齿的探索
深度学习的知识点,一般按照执行流程,有 网络层类型,归一化,激活函数,学习率,损失函数,优化器。如果是研究生上课学的应该系统一点,自学的话知识点一开始有点乱。
一、激活函数Sigmod下的权重曲线
先开始第一个知识点激活函数的探索,说到激活函数,它是干嘛的, y=f(w*x+b)。这y就是激活函数。那本文研究的是sigmod这个激活函数权重曲线为何出现锯齿?
首先大家熟悉下,sigmod函数,见下图,特点是什么? y值范围(0,1),x在(-5,5)之间呈现线性特征。
然后不废话,上代码,先看看锯齿的效果。
import torch
import matplotlib.pyplot as pltmodel = torch.nn.Sequential(torch.nn.Linear(1,1,bias=False),# torch.nn.BatchNorm1d(1),torch.nn.Sigmoid()
)# torch.nn.init.xavier_normal_(model[0].weight)X=torch.tensor([[10.0],[2.0],[11.0],[21.0]])
Y=torch.tensor([[0.9],[0.2],[0.92],[0.987]])optimizer = torch.optim.SGD(model.paramters(),lr=1)# 使用动量的SGD
optimizer = torch.optim.SGD(model.paramters(),lr=1,momentum=0.9)weight_history=[]
for step in range(100):optimizer.zero_grad()y_pred=model(X)loss = (y_pred-Y).pow(2).mean()loss.backward()optimizer.step()weight_history.append(model[0].weight.item())plt.plot(weight_history,'-',label='Sigmoid',color='orange')
plt.xlabel('Training Step')
plt.ylabel('Weight Value')
plt.title('Sigmoid Activation:Weight Oscillation')
plt.legend()
plt.savefig('juchi.jpg')
二、代码解析-kaiming均匀分布
我们先来看 torch.nn.Linear(1,1,bias=False)这句,这是一个线性函数,y=wx+b;
那么w和b为多少呢? bias为空,w呢,随机取值,但是默认符合kaiming均匀分布。
这个分布的范围呢?[-根号(1/特征数),根号(1/特征数)]。为何要这样搞,之前大伙确实随机权重,但是发展到一定阶段,发现规律了,如果这些权重符合某种分布,就能更好稳定训练。(对了,什么叫稳定训练,就是各层数据差异别太大,每一层的方差差不多的时候最好)
Kaiming初始化如何维持输入分布稳定?大家想想,对于y=f(wx+b);x是输入不能变的,激活函数也不能变。但是w可变的,通过改w可以使下一层稳定。
他让var(就是方差), var(x)=var(y)。然后var(y)=n_feautures*var(w)*var(x)。从而算出var(w)。然后均匀分布方差是 (b-a)平方/12。最终算出均匀分布的边界值(a,b)。说那么多就是为了算出均匀分布的边界值,然后才能随机嘛
三、产生锯齿的原因
产生锯齿的原因有很多种,例如学习率,数据值分布。今天我们探究的是非零中心性这个原因
Sigmoid的输出恒为正,可能导致后续层的权重梯度同号(如全正或全负),迫使优化路径呈“Z”字形调整,产生锯齿。
具体说明如下:
假设某全连接层的权重为 W,输入为 a(来自Sigmoid激活),输出为 z=W⋅a+b,损失函数为 L。根据链式法则,权重的梯度为:∂W/∂L=∂z/∂L⋅∂W/∂z=δ⋅a
其中:
- δ=∂z/∂L 是反向传播的误差项,
- a 是Sigmoid的输出,恒为正(a>0)。
关键点:梯度的符号由 δ 决定,而 δ 在反向传播中可能保持同号(全正或全负)。啥意思呢,就是这次梯度大了,然后反向时变小一点,然后下一次又小了,需要变大一点。
四、怎么解决呢?
1、入参归一化
torch.nn.BatchNorm1d(1),把数据搞成 均值为0,方差为1。这中学就学过吧。归一化的作用一句话讲,在下一层改变sigmoid输出的恒大于0的值,就是把x都变了,变成有负数了
2、使用动量的SGD
optimizer = torch.optim.SGD(model.paramters(),lr=1,momentum=0.9)
动量,白话讲,就是记住历史的变化,并非只依靠当前的,所以震荡小。好理解吧,就是船大不好掉头,船小好掉头。老外尽整一些新名词唬人。
3、换个方法,Tanh(),不用sigmod
但是这个具体看业务,并一定就能换