1. 随机梯度下降(SGD)
- 迭代格式:
x k + 1 = x k − η k ∇ f i ( x k ) x_{k+1} = x_k - \eta_k \nabla f_i(x_k) xk+1=xk−ηk∇fi(xk)
其中, η k \eta_k ηk 为步长(可能递减), ∇ f i ( x k ) \nabla f_i(x_k) ∇fi(xk) 是随机采样样本 i i i 的梯度估计。
- 优点:
- 盘算效率高,得当大规模数据集,每次迭代仅需单个样本的梯度 。
- 在强凸标题中收敛速度为 O ( 1 / t ) O(1/t) O(1/t),非凸标题中为 O ( 1 / log t ) O(1/\log t) O(1/logt) 。
- 理论分析成熟,易于实现 。
- 缺点:
- 收敛速度较慢,尤其在非凸标题中易陷入局部最优 。
- 对步长敏感,需要经心调解参数以包管稳定性 。
2. 重球随机梯度方法(SHB)
- 迭代格式:
x k + 1 = x k − η k ∇ f i ( x k ) + β ( x k − x k − 1 ) x_{k+1} = x_k - \eta_k \nabla f_i(x_k) + \beta (x_k - x_{k-1}) xk+1=xk−ηk∇fi(xk)+β(xk−xk−1)
其中, β ∈ ( 0 , 1 ) \beta \in (0,1) β∈(0,1) 为动量参数,通过汗青更新方向加速收敛。
- 优点:
- 动量项可加速收敛,尤其在光滑强凸标题中体现优于固定步长的SGD 。
- 对梯度噪声具有肯定鲁棒性,通过汗青梯度平均低落方差 。
- 缺点:
- 早期迭代可能体现不佳,收敛速度不肯定始终优于SGD 。
- 参数选择(如 β \beta β 和 η k \eta_k ηk)需谨慎,否则可能导致震荡或发散 。
- 在有限和随机设置中,缺乏严格的加速收敛证明 。
3. Nesterov随机梯度方法(SNAG)
- 迭代格式:
y k = x k + γ k ( x k − x k − 1 ) x k + 1 = y k − η k ∇ f i ( y k ) y_k = x_k + \gamma_k (x_k - x_{k-1}) \\ x_{k+1} = y_k - \eta_k \nabla f_i(y_k) yk=xk+γk(xk−xk−1)xk+1=yk−ηk∇fi(yk)
其中, γ k \gamma_k γk 为动量系数,通常在Nesterov方法中设计为时变参数。
- 优点:
- 在凸标题中理论收敛速度可达 O ( 1 / t 2 ) O(1/t^2) O(1/t2),显著快于SGD 。
- 通过“前瞻梯度”设计,淘汰震荡并进步稳定性 。
- 实验显示在分类和图像任务中优于传统动量方法 。
- 缺点:
- 随机情况下(如有限和设置)可能发散,需额外条件包管收敛 。
- 实现复杂度较高,需同时维护多个变量(如 x k x_k xk 和 y k y_k yk)。
- 参数调治更复杂,尤其在非凸标题中收敛性理论尚不美满 。
以上段落来自 秘塔 AI 综述的效果(先搜刮后扩展选项, 文献均来自中英文论文而非全网)。该完整版请移步至链接
https://metaso.cn/s/ThPU2bK
以下我们给出一组实验来探究 Nesterov 加速方法的参数选择, 收敛效果请大家自行验证,这里放上一个数值效果图作为代表
其中一点比较尴尬的现象是确定标题中 θ k = k − 1 k + 2 \theta_k=\frac{k-1}{k+2} θk=k+2k−1 类型的外插参数在随机标题中的数值实验中的体现并不好,有一子列不收敛到0,但是仍有大量文献包括课本,论文仍旧保举使用这类计谋。但是换成任何一个介于开区间 ( 0 , 1 ) (0,1) (0,1) 的常数,例如 0.9, 0.99 则有显着的序列收敛至0的趋势, 从本文给的算例来看是非常简单的凸二次 x 0 2 + x 1 2 + 2 ξ 0 x 0 + 2 ξ 1 x 0 x_0^2+x_1^2+2\xi_0 x_0+2\xi_1x_0 x02+x12+2ξ0x0+2ξ1x0,其中 ξ i \xi_i ξi 服从 N ( 0 , I ) N(0,I) N(0,I) 二维标准正态分布。为了压缩噪声影响,接纳递减步长 α k = 1 ( k + 2 ) γ \alpha_k=\frac{1}{(k+2)^\gamma} αk=(k+2)γ1。
- 规模小:仅2维标题
- 强凸
- 可微,且随机梯度关于自变量 x x x 是李普希兹连续的
- 随机样本噪声期望存在,方差有界
很难相信如许二维简单的例子参数 θ k = k − 1 k + 2 \theta_k=\frac{k-1}{k+2} θk=k+2k−1 都不收敛,其在大规模以及大数据标题中会具有较好的收敛效果,接待大家参与实验与讨论。
Python 代码如下:
- import numpy as np
- import matplotlib.pyplot as plt
- import numpy.linalg as la
- iters=1000000
- root=np.array([1.0,3.0])
- vec1=root.copy()
- vec2=root.copy()
- dim=len(root)
- path=np.zeros([iters,dim])
- def gobj(x,xi):
- return(2*(x+xi))
- gamma=1
- # (k-1)/(k+2) ===============================
- np.random.seed(0)
- for k in range(iters):
- theta= (k-1)/(k+2)
- root=(1.0+theta)*vec2-theta*vec1
- a=1/(k+1)**gamma
- xi=np.random.randn(2)
- vec1=vec2.copy()
- vec2=root - a*gobj(root,xi)
- path[k,:]=root
- V=np.zeros(iters)
- for k in range(iters):
- V[k]=la.norm(path[k,:])
- plt.loglog(V,'-.')
- plt.grid(True)
- # 0.99 ===============================
- iters=1000000
- root=np.array([1.0,3.0])
- vec1=root.copy()
- vec2=root.copy()
- dim=len(root)
- path=np.zeros([iters,dim])
- np.random.seed(0)
- for k in range(iters):
- theta= 0.99
- root=(1.0+theta)*vec2-theta*vec1
- a=1/(k+1)**gamma
- xi=np.random.randn(2)
- vec1=vec2.copy()
- vec2=root - a*gobj(root,xi)
- path[k,:]=root
- V=np.zeros(iters)
- for k in range(iters):
- V[k]=la.norm(path[k,:])
- plt.loglog(V,'--')
- plt.grid(True)
- # 0.9 ===============================
- iters=1000000
- root=np.array([1.0,3.0])
- vec1=root.copy()
- vec2=root.copy()
- dim=len(root)
- path=np.zeros([iters,dim])
- np.random.seed(0)
- for k in range(iters):
- theta= 0
- root=(1.0+theta)*vec2-theta*vec1
- a=1/(k+1)**gamma
- xi=np.random.randn(2)
- vec1=vec2.copy()
- vec2=root - a*gobj(root,xi)
- path[k,:]=root
- V=np.zeros(iters)
- for k in range(iters):
- V[k]=la.norm(path[k,:])
- plt.loglog(V,'.-')
- plt.grid(True)
- plt.legend(['(k-1)/(k+2)',0.99,0.5,'2/(k+2)'])
- plt.show()
复制代码 Matlab 代码如下
- % (k-1)/(k+2) ===============================
- init=[1,3];
- lth=length(init);
- fobj=@(x,xi)(x*x'+2*xi*x');
- gobj=@(x,xi)(2*x+2*xi);
- iters=1000000;
- path=ones(iters+1,length(init));
- path(1,:)=init;
- root=init;
- randn('seed',1)
- for k =1:iters
- if k<2
- xi=randn(1,lth);
- a=1/(k+2)^(2/3);
- root=root-a*gobj(root,xi);
- path(k+1,:)=root;
- else
- xi=randn(1,lth);
- a=1/(k+2)^(2/3);
- v=root-a*gobj(root,xi);
- path(k+1,:)=v;
- theta=(k-1)/(k+2);
- th=theta;
- root=(1+th)*path(k+1,:)-theta*path(k,:);
- end
- end
- Vk=ones(iters+1,1);
- for k=1:iters+1
- Vk(k)= path(k,:)*path(k,:)';
- end
- loglog(Vk,'--')
- grid on;
- hold on;
- % theta=0.99 ===============================
- init=[1,3];
- iters=1000000;
- path=ones(iters+1,length(init));
- path(1,:)=init;
- root=init;
- randn('seed',1)
- for k =1:iters
- if k<2
- xi=randn(1,lth);
- a=1/(k+2)^(2/3);
- root=root-a*gobj(root,xi);
- path(k+1,:)=root;
- else
- xi=randn(1,lth);
- a=1/(k+2)^(2/3);
- v=root-a*gobj(root,xi);
- path(k+1,:)=v;
- theta=0.99;
- th=theta;
- root=(1+th)*path(k+1,:)-theta*path(k,:);
- end
- end
- Vk=ones(iters+1,1);
- for k=1:iters+1
- Vk(k)= path(k,:)*path(k,:)';
- end
- loglog(Vk,'--')
- grid on;
- hold on;
- % theta=0.9 ===============================
- init=[1,3];
- iters=1000000;
- path=ones(iters+1,length(init));
- path(1,:)=init;
- root=init;
- randn('seed',1)
- for k =1:iters
- if k<2
- xi=randn(1,lth);
- a=1/(k+2)^(2/3);
- root=root-a*gobj(root,xi);
- path(k+1,:)=root;
- else
- xi=randn(1,lth);
- a=1/(k+2)^(2/3);
- v=root-a*gobj(root,xi);
- path(k+1,:)=v;
- theta=0.9;
- th=theta;
- root=(1+th)*path(k+1,:)-theta*path(k,:);
- end
- end
- Vk=ones(iters+1,1);
- for k=1:iters+1
- Vk(k)= path(k,:)*path(k,:)';
- end
- loglog(Vk,'--')
- grid on;
- hold on;
- % theta=0.9 ===================================================================
- init=[1,3];
- iters=1000000;
- path=ones(iters+1,length(init));
- path(1,:)=init;
- root=init;
- randn('seed',1)
- for k =1:iters
- if k<2
- xi=randn(1,lth)
- a=1/(k+2)^(2/3);
- root=root-a*gobj(root,xi);
- path(k+1,:)=root;
- else
- xi=randn(1,lth);
- a=1/(k+2)^(2/3);
- v=root-a*gobj(root,xi);
- path(k+1,:)=v;
- theta=0.5;
- th=theta;
- root=(1+th)*path(k+1,:)-theta*path(k,:);
- end
- end
- Vk=ones(iters+1,1);
- for k=1:iters+1
- Vk(k)= path(k,:)*path(k,:)';
- end
- loglog(Vk,'--')
- grid on;
- hold on;
- legend('(k-1)/(k+2)','0.99','0.9','0.5')
复制代码 免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
|