当前位置: 首页 > backend >正文

【机器学习】 关于外插修正随机梯度方法的数值实验

1. 随机梯度下降(SGD)
  • 迭代格式
    x k + 1 = x k − η k ∇ f i ( x k ) x_{k+1} = x_k - \eta_k \nabla f_i(x_k) xk+1=xkηkfi(xk)
    其中, η k \eta_k ηk 为步长(可能递减), ∇ f i ( x k ) \nabla f_i(x_k) fi(xk) 是随机采样样本 i i i 的梯度估计。
  • 优点
  • 计算效率高,适合大规模数据集,每次迭代仅需单个样本的梯度 。
  • 在强凸问题中收敛速度为 O ( 1 / t ) O(1/t) O(1/t),非凸问题中为 O ( 1 / log ⁡ t ) O(1/\log t) O(1/logt)
  • 理论分析成熟,易于实现 。
  • 缺点
  • 收敛速度较慢,尤其在非凸问题中易陷入局部最优 。
  • 对步长敏感,需要精心调整参数以保证稳定性 。
2. 重球随机梯度方法(SHB)
  • 迭代格式
    x k + 1 = x k − η k ∇ f i ( x k ) + β ( x k − x k − 1 ) x_{k+1} = x_k - \eta_k \nabla f_i(x_k) + \beta (x_k - x_{k-1}) xk+1=xkηkfi(xk)+β(xkxk1)
    其中, β ∈ ( 0 , 1 ) \beta \in (0,1) β(0,1) 为动量参数,通过历史更新方向加速收敛。
  • 优点
    • 动量项可加速收敛,尤其在光滑强凸问题中表现优于固定步长的SGD 。
    • 对梯度噪声具有一定鲁棒性,通过历史梯度平均降低方差 。
  • 缺点
    • 早期迭代可能表现不佳,收敛速度不一定始终优于SGD 。
    • 参数选择(如 β \beta β η k \eta_k ηk)需谨慎,否则可能导致震荡或发散 。
    • 在有限和随机设置中,缺乏严格的加速收敛证明 。
3. Nesterov随机梯度方法(SNAG)
  • 迭代格式
    y k = x k + γ k ( x k − x k − 1 ) x k + 1 = y k − η k ∇ f i ( y k ) y_k = x_k + \gamma_k (x_k - x_{k-1}) \\ x_{k+1} = y_k - \eta_k \nabla f_i(y_k) yk=xk+γk(xkxk1)xk+1=ykηkfi(yk)
    其中, γ k \gamma_k γk 为动量系数,通常在Nesterov方法中设计为时变参数。
  • 优点
    • 在凸问题中理论收敛速度可达 O ( 1 / t 2 ) O(1/t^2) O(1/t2),显著快于SGD 。
    • 通过“前瞻梯度”设计,减少震荡并提高稳定性 。
    • 实验显示在分类和图像任务中优于传统动量方法 。
  • 缺点
    • 随机环境下(如有限和设置)可能发散,需额外条件保证收敛 。
    • 实现复杂度较高,需同时维护多个变量(如 x k x_k xk y k y_k yk)。
  • 参数调节更复杂,尤其在非凸问题中收敛性理论尚不完善 。

以上段落来自 秘塔 AI 综述的结果(先搜索后扩展选项, 文献均来自中英文论文而非全网)。该完整版请移步至链接

https://metaso.cn/s/ThPU2bK

以下我们给出一组实验来探讨 Nesterov 加速方法的参数选择, 收敛效果请大家自行验证,这里放上一个数值结果图作为代表
在这里插入图片描述

其中一点比较尴尬的现象是确定问题中 θ k = k − 1 k + 2 \theta_k=\frac{k-1}{k+2} θk=k+2k1 类型的外插参数在随机问题中的数值实验中的表现并不好,有一子列不收敛到0,但是仍有大量文献包括教材,论文仍然推荐使用这类策略。但是换成任何一个介于开区间 ( 0 , 1 ) (0,1) (0,1) 的常数,例如 0.9, 0.99 则有明显的序列收敛至0的趋势, 从本文给的算例来看是非常简单的凸二次 x 0 2 + x 1 2 + 2 ξ 0 x 0 + 2 ξ 1 x 0 x_0^2+x_1^2+2\xi_0 x_0+2\xi_1x_0 x02+x12+2ξ0x0+2ξ1x0,其中 ξ i \xi_i ξi 服从 N ( 0 , I ) N(0,I) N(0,I) 二维标准正态分布。为了压缩噪声影响,采用递减步长 α k = 1 ( k + 2 ) γ \alpha_k=\frac{1}{(k+2)^\gamma} αk=(k+2)γ1

  1. 规模小:仅2维问题
  2. 强凸
  3. 可微,且随机梯度关于自变量 x x x 是李普希兹连续的
  4. 随机样本噪声期望存在,方差有界

很难相信这样二维简单的例子参数 θ k = k − 1 k + 2 \theta_k=\frac{k-1}{k+2} θk=k+2k1 都不收敛,其在大规模以及大数据问题中会具有较好的收敛效果,欢迎大家参与实验与讨论。

Python 代码如下:

import numpy as np
import matplotlib.pyplot as plt
import numpy.linalg as la
iters=1000000
root=np.array([1.0,3.0])
vec1=root.copy()
vec2=root.copy()
dim=len(root)
path=np.zeros([iters,dim])
def gobj(x,xi):return(2*(x+xi))
gamma=1#  (k-1)/(k+2)  ===============================
np.random.seed(0)
for k in range(iters):    theta= (k-1)/(k+2)root=(1.0+theta)*vec2-theta*vec1a=1/(k+1)**gammaxi=np.random.randn(2)vec1=vec2.copy()vec2=root - a*gobj(root,xi)path[k,:]=root
V=np.zeros(iters)
for k in range(iters):V[k]=la.norm(path[k,:])
plt.loglog(V,'-.')
plt.grid(True)# 0.99    ===============================
iters=1000000
root=np.array([1.0,3.0])
vec1=root.copy()
vec2=root.copy()
dim=len(root)
path=np.zeros([iters,dim])
np.random.seed(0)
for k in range(iters):    theta= 0.99root=(1.0+theta)*vec2-theta*vec1a=1/(k+1)**gammaxi=np.random.randn(2)vec1=vec2.copy()vec2=root - a*gobj(root,xi)path[k,:]=root
V=np.zeros(iters)
for k in range(iters):V[k]=la.norm(path[k,:])
plt.loglog(V,'--')
plt.grid(True)# 0.9  ===============================
iters=1000000
root=np.array([1.0,3.0])
vec1=root.copy()
vec2=root.copy()
dim=len(root)
path=np.zeros([iters,dim])
np.random.seed(0)
for k in range(iters):    theta= 0root=(1.0+theta)*vec2-theta*vec1a=1/(k+1)**gammaxi=np.random.randn(2)vec1=vec2.copy()vec2=root - a*gobj(root,xi)path[k,:]=root
V=np.zeros(iters)
for k in range(iters):V[k]=la.norm(path[k,:])
plt.loglog(V,'.-')
plt.grid(True)plt.legend(['(k-1)/(k+2)',0.99,0.5,'2/(k+2)'])
plt.show()

Matlab 代码如下

% (k-1)/(k+2)   ===============================
init=[1,3];
lth=length(init);
fobj=@(x,xi)(x*x'+2*xi*x');
gobj=@(x,xi)(2*x+2*xi);
iters=1000000;
path=ones(iters+1,length(init));
path(1,:)=init;
root=init;
randn('seed',1)
for k =1:itersif k<2xi=randn(1,lth);a=1/(k+2)^(2/3);root=root-a*gobj(root,xi);path(k+1,:)=root;elsexi=randn(1,lth);a=1/(k+2)^(2/3);v=root-a*gobj(root,xi);path(k+1,:)=v;theta=(k-1)/(k+2);th=theta;root=(1+th)*path(k+1,:)-theta*path(k,:);end
end
Vk=ones(iters+1,1);
for k=1:iters+1Vk(k)= path(k,:)*path(k,:)';
end
loglog(Vk,'--')
grid on;
hold on;% theta=0.99    ===============================
init=[1,3];
iters=1000000;
path=ones(iters+1,length(init));
path(1,:)=init;
root=init;
randn('seed',1)
for k =1:itersif k<2xi=randn(1,lth);a=1/(k+2)^(2/3);root=root-a*gobj(root,xi);path(k+1,:)=root;elsexi=randn(1,lth);a=1/(k+2)^(2/3);v=root-a*gobj(root,xi);path(k+1,:)=v;theta=0.99;th=theta;root=(1+th)*path(k+1,:)-theta*path(k,:);end
end
Vk=ones(iters+1,1);
for k=1:iters+1Vk(k)= path(k,:)*path(k,:)';
end
loglog(Vk,'--')
grid on;
hold on;% theta=0.9     ===============================
init=[1,3];
iters=1000000;
path=ones(iters+1,length(init));
path(1,:)=init;
root=init;
randn('seed',1)
for k =1:itersif k<2xi=randn(1,lth);a=1/(k+2)^(2/3);root=root-a*gobj(root,xi);path(k+1,:)=root;elsexi=randn(1,lth);a=1/(k+2)^(2/3);v=root-a*gobj(root,xi);path(k+1,:)=v;theta=0.9;th=theta;root=(1+th)*path(k+1,:)-theta*path(k,:);end
end
Vk=ones(iters+1,1);
for k=1:iters+1Vk(k)= path(k,:)*path(k,:)';
end
loglog(Vk,'--')
grid on;
hold on;% theta=0.9  ===================================================================
init=[1,3];iters=1000000;
path=ones(iters+1,length(init));
path(1,:)=init;
root=init;
randn('seed',1)
for k =1:itersif k<2xi=randn(1,lth)a=1/(k+2)^(2/3);root=root-a*gobj(root,xi);path(k+1,:)=root;elsexi=randn(1,lth);a=1/(k+2)^(2/3);v=root-a*gobj(root,xi);path(k+1,:)=v;theta=0.5;th=theta;root=(1+th)*path(k+1,:)-theta*path(k,:);end
end
Vk=ones(iters+1,1);
for k=1:iters+1Vk(k)= path(k,:)*path(k,:)';
end
loglog(Vk,'--')
grid on;
hold on;
legend('(k-1)/(k+2)','0.99','0.9','0.5')
http://www.xdnf.cn/news/8245.html

相关文章:

  • 听脑AI:革新沟通方式,开启高效信息时代
  • 核实发票的真实性与合法性-发票查验接口-虚假发票防范
  • 关于Newtonsoft版本不兼容问题处理
  • sentinel滑动时间窗口算法详解
  • 系统性能分析基本概念(3) : Tuning Efforts
  • imuerrset
  • PT8P2104触控型8Bit MCU
  • 【Django Serializer】一篇文章详解 Django 序列化器
  • deep-rtsp 摄像头rtsp配置工具
  • 高频与超高频RFID读写器技术应用差异解析
  • 论文解读: 2018-Detection of spam reviews: a sentiment analysis approach
  • 宝尊电商一季度净收入21亿元 品牌管理收入同比大增
  • 冲刺卷软考总结-案例分析
  • 地信GIS专业关于学习、考研、就业方面的一些问题答疑
  • Windows系统下Docker安装青龙面板
  • 常见高危端口解析:网络安全中的“危险入口”
  • 101个α因子#15
  • CentOS7安装 PHP-FPM 7.4
  • 2025海外短剧CPS系统开发指南:高付费市场解析与增速全景图
  • 【CSS】九宫格布局
  • openEuler 22.03 LTS-SP3 系统安装 docker 26.1.3、docker-compose
  • 微信小程序之Promise-Promise初始用
  • 笔记:将一个文件服务器上的文件(一个返回文件数据的url)作为另一个http接口的请求参数
  • 重读《人件》Peopleware -(11)Ⅱ 办公环境 Ⅳ 插曲:生产力测量与不明飞行物
  • Nginx核心功能
  • 【Linux系统】冯诺依曼体系结构 和 操作系统的介绍
  • Ctrl+鼠标滚动阻止页面放大/缩小
  • QFileDialog::getSaveFileName导致系统崩溃
  • Go语言gopacket库的HTTP协议分析工具实现
  • 学习人工智能:从0到1的破局指南与职业成长路径