Adam

ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION

解决了什么问题

基于传统随机梯度下降算法的一阶优化算法

贡献点

  1. 结合了AdaGrad(自适应梯度算法)RMSProp(均方根传播算法)两种梯度优化算法的优点

优势

  1. 实现简单,计算高效,对内存需求少

  2. 参数的更新不受梯度的伸缩变换影响

  3. 超参数具有很好的解释性,且通常无需调整或仅需很少的微调

  4. 更新的步长能够被限制在大致的范围内(初始学习率)

  5. 能自然地实现步长退火过程(自动调整学习率)

  6. 很适合应用于大规模的数据及参数的场景

  7. 适用于不稳定目标函数

  8. 适用于梯度稀疏或梯度存在很大噪声的问题

缺点

  1. 在某些极端情况下不会收敛到最优值
  2. Adam 可能会遇到权重衰减问题

方法

Adam Algorithm

参数:

A B C
1 mtmt 一阶矩向量(the mean)
2 vtvt 二阶矩向量(the uncentered variance)
3 t 迭代数
4 θtθt 参数向量
5 ft(θ)ft(θ) 随机目标函数
6 gtgt 梯度
7 β1,β2β1,β2 矩估计的指数衰减率
8 αα 步长, 算法中通过αt=α1βt1αt=α1βt1更新
9 ^mt^mt 偏差修正的一阶矩估计
10 ^vt^vt 偏差修正的二阶原始矩估计

数据集

  1. MNIST
  2. IMDb Movie Reviews
  3. CIFAR10

效果

  1. LOGISTIC REGRESSION Figure 1
  2. MULTI-LAYER NEURAL NETWORKS Figure 2
  3. CONVOLUTIONAL NEURAL NETWORKS Figure 3
  4. BIAS-CORRECTION TERM Figure 4

收敛性证明

ft(θ)ft(θ)convex函数时,判定指标选择为统计量R(T)(Regret):

R(T)=Ti=1[ft(θt)ft(θ)]R(T)=Ti=1[ft(θt)ft(θ)]

其中 θ=argminθχTi=1ft(θ)θ=argminθχTi=1ft(θ)

方法: 求取 R(T) 的一个上界,然后利用上界值除以 T 在极限情况下是否趋于0来判定收敛性。

R(T)T=O(1T)R(T)T=O(1T)

假设:

  1. gtft(θt)gtft(θt),其中iRtiRt表示第i维从1到t的梯度向量,表示为g1:t,i=[g1,i,g2,igt,i]g1:t,i=[g1,i,g2,igt,i]
  2. γβ21β2<1,β1,β2[0,1)γβ21β2<1,β1,β2[0,1)
  3. β1,t=β1λt1,λ(0,1)β1,tβ1β1,t=β1λt1,λ(0,1)β1,tβ1
  4. αt=αtαt=αt

平均不等式:

x1,x2xnN,0<n1/x1+1/x2++1/xnnx1x2xnx1+x2++xnnx21+x22++x2nn0<nni=11xinni=1xini=1xinni=1x2inx1,x2xnN,0<n1/x1+1/x2++1/xnnx1x2xnx1+x2++xnnx21+x22++x2nn0<nni=11xin ni=1xini=1xinni=1x2in(0)

根据定义, 如果f是一个凸函数, 则对于所有的x,yRd,λ[0,1]x,yRd,λ[0,1]都有 λf(x)+(1λ)f(y)f(λx+(1λ)y)f(y)f(x)+f(x)T(yx)λf(x)+(1λ)f(y)f(λx+(1λ)y)f(y)f(x)+f(x)T(yx)(1)(2)

由(2)式可得,

ft(θ)ft(θt)+gTt(θθt)ft(θt)ft(θ)gTt(θtθ)R(T)Tt=1gTt,θtθ=di=1Tt=1gt,i,θt,iθ,ift(θ)ft(θt)+gTt(θθt)ft(θt)ft(θ)gTt(θtθ)R(T)Tt=1gTt,θtθ=di=1Tt=1gt,i,θt,iθ,i(3)

对于算法1更新参数:

θt+1=θtαtˆmt^vt=θtαt1βt1(β1,tmt1+(1β1,t)gt^vt)θt+1=θtαt^mt^vt=θtαt1βt1(β1,tmt1+(1β1,t)gt^vt)

对两边减去参数theta然后平方可得:

(θt+1,iθ,i)2=(θt,iθ,iαtˆmt,i^vt,i)2=(θt,iθ,i)2+(αtˆmt,i^vt,i)22αtˆmt,i^vt,i(θt,iθ,i)=(θt,iθ,i)2+(αtˆmt,i^vt,i)22αt1βt1(β1,tmt1+(1β1,t)gt,i^vt,i)(θt,iθ,i)(θt+1,iθ,i)2=(θt,iθ,iαt^mt,i^vt,i)2=(θt,iθ,i)2+(αt^mt,i^vt,i)22αt^mt,i^vt,i(θt,iθ,i)=(θt,iθ,i)2+(αt^mt,i^vt,i)22αt1βt1(β1,tmt1+(1β1,t)gt,i^vt,i)(θt,iθ,i)

对上式进行分离缩放:

gt,i(θt,iθ,i)=ˆvt,i(1βt1)((θt+1,iθ,i)2(θt,iθ,i)2(αt^mt,i^vt,i)2)2αtβ1,tmt1,i(θt,iθ,i)1β1,t=ˆvt,i(1βt1)((θt+1,iθ,i)2(θt,iθ,i)2(αt^mt,i^vt,i)2)2αt(1β1,t)β1,tmt1,i(θt,iθ,i)1β1,t=ˆvt,i(1βt1)((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αtˆvt,i(1βt1)2(1β1,t)(^mt,i^vt,i)2β1,tmt1,i(θt,iθ,i)1β1,t=ˆvt,i(1βt1)((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αtˆvt,i(1βt1)2(1β1,t)(^mt,i^vt,i)2+β1,tmt1,i(θ,iθt,i)ˆv14t1,iαt1αt1ˆv14t1,i1β1,tˆvt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)ˆm2t,i^vt,i+β1,t(θ,iθt,i)ˆv14t1,iαt1mt1,iαt1ˆv14t1,i1β1,tˆvt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)ˆm2t,i^vt,i+β1,tˆvt1,i2αt1(1β1,t)(θ,iθt,i)2+β1,tm2t1,iαt12(1β1,t)ˆvt1,igt,i(θt,iθ,i)=^vt,i(1βt1)((θt+1,iθ,i)2(θt,iθ,i)2(αt^mt,i^vt,i)2)2αtβ1,tmt1,i(θt,iθ,i)1β1,t=^vt,i(1βt1)((θt+1,iθ,i)2(θt,iθ,i)2(αt^mt,i^vt,i)2)2αt(1β1,t)β1,tmt1,i(θt,iθ,i)1β1,t=^vt,i(1βt1)((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt^vt,i(1βt1)2(1β1,t)(^mt,i^vt,i)2β1,tmt1,i(θt,iθ,i)1β1,t=^vt,i(1βt1)((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt^vt,i(1βt1)2(1β1,t)(^mt,i^vt,i)2+β1,tmt1,i(θ,iθt,i)^v14t1,iαt1αt1^v14t1,i1β1,t^vt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)^m2t,i^vt,i+β1,t(θ,iθt,i)^v14t1,iαt1mt1,iαt1^v14t1,i1β1,t^vt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)^m2t,i^vt,i+β1,t^vt1,i2αt1(1β1,t)(θ,iθt,i)2+β1,tm2t1,iαt12(1β1,t)^vt1,i(4)

注释: 根据杨氏不等式aba22+b22aba22+b22,

其中

a=(θ,iθt,i)ˆv14t1,iαt1b=mt1,iαt1ˆv14t1,ia=(θ,iθt,i)^v14t1,iαt1b=mt1,iαt1^v14t1,i

把(4)式带入(3)式可得,

R(T)di=1Tt=1[ˆvt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)ˆm2t,i^vt,i+β1,tˆvt1,i2αt1(1β1,t)(θ,iθt1,i)2+β1,tm2t1,iαt12(1β1,t)ˆvt1,i]R(T)di=1Tt=1[^vt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)^m2t,i^vt,i+β1,t^vt1,i2αt1(1β1,t)(θ,iθt1,i)2+β1,tm2t1,iαt12(1β1,t)^vt1,i]

其中

Tt=1ˆm2t,itˆvt,i21γG1β2g1:T,i2Tt=1^m2t,it^vt,i21γG1β2g1:T,i2(5)

证明:

假设: (1βt2)(1βt1)21(1β1)2(1βt2)(1βt1)21(1β1)2(6)

参数有界:

θnθm2DθnθmDθnθm2DθnθmD(7)

其中,

ˆvt,i=tj=1(1β2)βtj2g2j,i1βt2g1:t,i2ˆm2t,i=(β1mt1,i+(1β1)gt,i)2(1βt1)2=(tk=1βtk1(1β1)gk,i)2(1βt1)2g1:t,i22^vt,i=tj=1(1β2)βtj2g2j,i1βt2g1:t,i2^m2t,i=(β1mt1,i+(1β1)gt,i)2(1βt1)2=(tk=1βtk1(1β1)gk,i)2(1βt1)2g1:t,i22(8)(9)

展开不等式(5)左边的最后一项:

Tt=1ˆm2t,itˆvt,i=T1t=1ˆm2t,itˆvt,i+ˆm2T,iTˆvT,i=T1t=1ˆm2t,itˆvt,i+m2T,i(1βT1)2TvT,i1βT2=T1t=1ˆm2t,itˆvt,i+(β1mT1,i+(1β1)gT,i)2(1βT1)2Tβ2vT1,i+(1β2)g2T,i1βT2=T1t=1ˆm2t,itˆvt,i+1βT2(1βT1)2(Tk=1(1β1)βTk1gk,i)2TTj=1(1β2)βTj2g2j,iT1t=1ˆm2t,itˆvt,i+1βT2(1βT1)2Tk=1T((1β1)βTk1gk,i)2TTj=1(1β2)βTj2g2j,iT1t=1ˆm2t,itˆvt,i+1βT2(1βT1)2Tk=1T((1β1)βTk1gk,i)2T(1β2)βTk2g2k,iT1t=1ˆm2t,itˆvt,i+1βT2(1βT1)2(1β1)2T(1β2)Tk=1T(β21β2)Tkgk,i2T1t=1ˆm2t,itˆvt,i+TT(1β2)Tk=1γTkgk,i2Tt=1ˆm2t,itˆvt,iTt=1gt,i2t(1β2)Ttj=0tγjTt=1gt,i2t(1β2)Tj=0tγjTt=1^m2t,it^vt,i=T1t=1^m2t,it^vt,i+^m2T,iT^vT,i=T1t=1^m2t,it^vt,i+m2T,i(1βT1)2TvT,i1βT2=T1t=1^m2t,it^vt,i+(β1mT1,i+(1β1)gT,i)2(1βT1)2Tβ2vT1,i+(1β2)g2T,i1βT2=T1t=1^m2t,it^vt,i+1βT2(1βT1)2(Tk=1(1β1)βTk1gk,i)2TTj=1(1β2)βTj2g2j,iT1t=1^m2t,it^vt,i+1βT2(1βT1)2Tk=1T((1β1)βTk1gk,i)2TTj=1(1β2)βTj2g2j,iT1t=1^m2t,it^vt,i+1βT2(1βT1)2Tk=1T((1β1)βTk1gk,i)2T(1β2)βTk2g2k,iT1t=1^m2t,it^vt,i+1βT2(1βT1)2(1β1)2T(1β2)Tk=1T(β21β2)Tkgk,i2T1t=1^m2t,it^vt,i+TT(1β2)Tk=1γTkgk,i2Tt=1^m2t,it^vt,iTt=1gt,i2t(1β2)Ttj=0tγjTt=1gt,i2t(1β2)Tj=0tγj((0))((6)2)(10)

根据等比数列求和公式:

Sn=a(1rn)1r,r1S=a1r,1<r<1Tj=0tγj=t1γT1γ<t1γSn=a(1rn)1r,r1S=a1r,1<r<1Tj=0tγj=t1γT1γ<t1γ(11)

把(11)式带入(10)式得,

Tt=1gt,i2t(1β2)Tj=0tγj1(1γ)1β2Tt=1gt,i2tTt=1gt,i2t(1β2)Tj=0tγj1(1γ)1β2Tt=1gt,i2t(12)

梯度有界:

gt2G,gtGTt=1g2t,it2Gg1:T,i2gt2G,gtGTt=1g2t,it2Gg1:T,i2(13)

运用数学归纳法证明:

当T=1时, g21,i2Gg1:T,i2g21,i2Gg1:T,i2
当T=T时, Tt=1g2t,it=T1t=1g2t,it+g2T,iT2Gg1:T1,i2+g2T,iT=2Gg1:T,i22g2T+g2T,iTTt=1g2t,it=T1t=1g2t,it+g2T,iT2Gg1:T1,i2+g2T,iT=2Gg1:T,i22g2T+g2T,iT(14)

g1:T,i22g2T+g4T,i4g1:T,i22g1:T,i22g2Tg1:T,i22g2T+g4T,i4g1:T,i22g1:T,i22g2T

两边开根号得,

g1:T,i22g2Tg1:T,i2g2T,i2g1:T,i2g1:T,i2g2T,i2TG2g1:T,i22g2Tg1:T,i2g2T,i2g1:T,i2g1:T,i2g2T,i2TG2

对不等号左边换算成公式(14):

Gg1:T,i22g2TGg1:T,i2g2T,i2T2Gg1:T,i22g2T2Gg1:T,i2g2T,iT2Gg1:T,i22g2T+g2T,iT2Gg1:T,i2g2T,iT+g2T,iT2Gg1:T,i22g2T+g2T,iT2Gg1:T,i2g2T,igT,iT2Gg1:T,i2Tt=1g2t,it2Gg1:T,i2Gg1:T,i22g2TGg1:T,i2g2T,i2T2Gg1:T,i22g2T2Gg1:T,i2g2T,iT2Gg1:T,i22g2T+g2T,iT2Gg1:T,i2g2T,iT+g2T,iT2Gg1:T,i22g2T+g2T,iT2Gg1:T,i2g2T,igT,iT2Gg1:T,i2Tt=1g2t,it2Gg1:T,i2

把公式(11)带入公式(10)得,

Tt=1ˆm2t,itˆvt,i2G(1γ)1β2g1:T,i2Tt=1^m2t,it^vt,i2G(1γ)1β2g1:T,i2(15)

所以,

R(T)di=1Tt=1[ˆvt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)ˆm2t,i^vt,i+β1,tˆvt1,i2αt1(1β1,t)(θ,iθt,i)2+β1,tm2t1,iαt12(1β1,t)ˆvt1,i]di=1Tt=1[ˆvt,i(θt,iθ,i)22αt(1β1,t)ˆvt,i(θt+1,iθ,i)22αt(1β1,t)+β1,tˆvt1,i2αt1(1β1,t)(θ,iθt,i)2+αt2(1β1,t)ˆm2t,i^vt,i+β1,tm2t1,iαt12(1β1,t)ˆvt1,i]di=1[ˆv1,i(θ1,iθ,i)22α1(1β1,1)ˆvT,i(θT+1,iθ,i)22αT(1β1,T)]+di=1Tt=2[ˆvt,i(θt,iθ,i)22αt(1β1,t)ˆvt1,i(θt,iθ,i)22αt1(1β1,t1)]+di=1Tt=1[β1,tˆvt1,i2αt1(1β1,t)(θ,iθt,i)2+αt2(1β1,t)ˆm2t,i^vt,i+β1,tαt12(1β1,t)m2t1,iˆvt1,i]di=1[ˆv1,i(θ1,iθ,i)22α1(1β1,1)ˆvT,i(θT+1,iθ,i)22αT(1β1,T)]+12(1β1)di=1Tt=2(θt,iθ,i)2(ˆvt,iαtˆvt1,iαt1)+di=1Tt=1[β1,tˆvt1,i2αt1(1β1,t)(θ,iθt,i)2]+α2(1β1)2G(1γ)1β2di=1g1:T,i2+αβ12(1β1)2G(1γ)1β2di=1g1:T,i2di=1[ˆv1,iD22α1(1β1,1)ˆvT,iD22αT(1β1,T)]+D22α1(1β1)di=1Tt=2(tˆvt,i(t1)ˆvt1,i)+D22α1di=1[β1,t(t1)ˆvt1,i(1β1,t)]+α1(1+β1)(1β1)G(1γ)1β2di=1g1:T,i2di=1[ˆv1,iD22α1(1β1,1)ˆvT,iD22αT(1β1,T)]+D22α1(1β1)di=1(TˆvT,iˆv1,i)+D2G2α1di=1[β1,t(t1)(1β1,t)]+α1(1+β1)(1β1)G(1γ)1β2di=1g1:T,i2R(T)di=1Tt=1[^vt,i((θt,iθ,i)2(θt+1,iθ,i)2)2αt(1β1,t)+αt2(1β1,t)^m2t,i^vt,i+β1,t^vt1,i2αt1(1β1,t)(θ,iθt,i)2+β1,tm2t1,iαt12(1β1,t)^vt1,i]di=1Tt=1[^vt,i(θt,iθ,i)22αt(1β1,t)^vt,i(θt+1,iθ,i)22αt(1β1,t)+β1,t^vt1,i2αt1(1β1,t)(θ,iθt,i)2+αt2(1β1,t)^m2t,i^vt,i+β1,tm2t1,iαt12(1β1,t)^vt1,i]di=1[^v1,i(θ1,iθ,i)22α1(1β1,1)^vT,i(θT+1,iθ,i)22αT(1β1,T)]+di=1Tt=2[^vt,i(θt,iθ,i)22αt(1β1,t)^vt1,i(θt,iθ,i)22αt1(1β1,t1)]+di=1Tt=1[β1,t^vt1,i2αt1(1β1,t)(θ,iθt,i)2+αt2(1β1,t)^m2t,i^vt,i+β1,tαt12(1β1,t)m2t1,i^vt1,i]di=1[^v1,i(θ1,iθ,i)22α1(1β1,1)^vT,i(θT+1,iθ,i)22αT(1β1,T)]+12(1β1)di=1Tt=2(θt,iθ,i)2(^vt,iαt^vt1,iαt1)+di=1Tt=1[β1,t^vt1,i2αt1(1β1,t)(θ,iθt,i)2]+α2(1β1)2G(1γ)1β2di=1g1:T,i2+αβ12(1β1)2G(1γ)1β2di=1g1:T,i2di=1[^v1,iD22α1(1β1,1)^vT,iD22αT(1β1,T)]+D22α1(1β1)di=1Tt=2(t^vt,i(t1)^vt1,i)+D22α1di=1[β1,t(t1)^vt1,i(1β1,t)]+α1(1+β1)(1β1)G(1γ)1β2di=1g1:T,i2di=1[^v1,iD22α1(1β1,1)^vT,iD22αT(1β1,T)]+D22α1(1β1)di=1(T^vT,i^v1,i)+D2G2α1di=1[β1,t(t1)(1β1,t)]+α1(1+β1)(1β1)G(1γ)1β2di=1g1:T,i2((12)4)((7))((8))

其中, di=1[β1,t(t1)(1β1,t)]di=111β1λt1t1di=111β1λt1(t1)11β1λ(1λ)2di=1[β1,t(t1)(1β1,t)]di=111β1λt1t1di=111β1λt1(t1)11β1λ(1λ)2(16)

注释:

n0nxn=n0x(xn)=xn0(xn)=x(n0(xn))=x(11x)=x(1x)2n0nxn=n0x(xn)=xn0(xn)=x(n0(xn))=x(11x)=x(1x)2(,nx^{n-1})()(|x|<1)

最后,

R(T)di=1[ˆv1,iD22α1(1β1,1)ˆvT,iD22αT(1β1,T)]+D22α1(1β1)di=1(TˆvT,iˆv1,i)+D2G2α1di=1[β1,t(t1)(1β1,t)]+α1(1+β1)(1β1)G(1γ)1β2di=1g1:T,i2di=1[ˆv1,iD22α1(1β1,1)TˆvT,iD22α1(1β1,T)]+D22α1(1β1)di=1(TˆvT,iˆv1,i)+di=1λD2G2α1(1β1)(1λ)2+α1(1+β1)G(1γ)1β2(1β1)di=1g1:T,i2R(T)di=1[^v1,iD22α1(1β1,1)^vT,iD22αT(1β1,T)]+D22α1(1β1)di=1(T^vT,i^v1,i)+D2G2α1di=1[β1,t(t1)(1β1,t)]+α1(1+β1)(1β1)G(1γ)1β2di=1g1:T,i2di=1[^v1,iD22α1(1β1,1)T^vT,iD22α1(1β1,T)]+D22α1(1β1)di=1(T^vT,i^v1,i)+di=1λD2G2α1(1β1)(1λ)2+α1(1+β1)G(1γ)1β2(1β1)di=1g1:T,i2((16))

改进方向

ADAMAX

ADAMAX

AdamW

AdamW

AdamX

AdamX

Code精读

Code

m, v = exp_avg, exp_avg_sq

mtβ1mt1+(1β1)gtmtβ1mt1+(1β1)gt
vtβ2vt1+(1β2)g2tvtβ2vt1+(1β2)g2t

m.mul_(beta1).add_(grad, alpha=1 - beta1)
v.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

偏差修正

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

更新参数向量θtθt

θtθt1αˆmt/(ˆvt+ϵ)θtθt1αmt(1βt1)/(vt(1βt2)+ϵ)θtθt1α^mt/(^vt+ϵ)θtθt1αmt(1βt1)/(vt(1βt2)+ϵ)
param.addcdiv_(m, denom, value=-step_size)

其中分母vt(1βt2)+ϵvt(1βt2)+ϵ

denom = (v.sqrt() / math.sqrt(bias_correction2)).add_(eps)

学习率αt=α(1βt1)αt=α(1βt1)

step_size = lr / bias_correction1

Adam of pytorch

input:γ(lr),β1,β2(betas),θ0(params),f(θ)(objective),ϵ (epsilon)λ(weight decay),amsgrad,maximizeinitialize:m00 (first moment),v00 ( second moment),^v0max0fort=1todoifmaximize:gtθft(θt1)elsegtθft(θt1)θtθt1γλθt1mtβ1mt1+(1β1)gtvtβ2vt1+(1β2)g2t^mtmt/(1βt1)^vtvt/(1βt2)ifamsgrad^vtmaxmax(^vtmax,^vt)θtθtγ^mt/(^vtmax+ϵ)elseθtθtγ^mt/(^vt+ϵ)returnθtinput:γ(lr),β1,β2(betas),θ0(params),f(θ)(objective),ϵ (epsilon)λ(weight decay),amsgrad,maximizeinitialize:m00 (first moment),v00 ( second moment),ˆv0max0fort=1todoifmaximize:gtθft(θt1)elsegtθft(θt1)θtθt1γλθt1mtβ1mt1+(1β1)gtvtβ2vt1+(1β2)g2tˆmtmt/(1βt1)ˆvtvt/(1βt2)ifamsgradˆvtmaxmax(ˆvtmax,ˆvt)θtθtγˆmt/(ˆvtmax+ϵ)elseθtθtγˆmt/(ˆvt+ϵ)returnθt

扩展

  1. 随机梯度下降算法发展史: SGD -> SGDM -> NAG ->AdaGrad -> AdaDelta -> RMSprop -> Adam -> AdaMax -> Nadam
  2. An overview of gradient descent optimization algorithms∗
  3. ON THE CONVERGENCE PROOF OF AMSGRAD AND A NEW VERSION?