余弦退火调整学习率 CosineAnnealingLR




使用梯度下降算法优化目标函数时,当接近损失函数的全局最小值时,学习率应该变得更小以使得模型尽可能接近这一点,所以需要对学习率进行衰减。

余弦函数的特点是,随着自变量 x 的增大,余弦函数值先缓慢下降,然后加速下降,再减速下降,所以常用余弦函数来降低学习率称之为余弦退火(Cosine Annealing)对于每个周期都会按如下公式进行学习率的衰减工作

由于刚开始训练时,模型的权重是随机初始化的,此时若选择一个较大的学习率,模型可能会出现振荡现象。利用训练预热 (Warmup) 学习率的方法,使得前几个周期内的学习率较小,在较小的学习率的预热下模型将逐步趋于稳定,当模型较为稳定后便使用预先设置的学习率进行训练,这有利于加快模型的收敛速度,模型效果更佳



CosineAnnealingLR


这个比较简单,只对其中的最关键的 Tmax 参数作一个说明, 这个可以理解为余弦函数的半周期. 如果 max_epoch=50 次,那么设置 Tmax=5 则会让学习率余弦周期性变化 5 次.




CosineAnnealingWarmRestarts

class torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1)

  • optimizer:优化器
  • T_0:重启后的迭代次数
  • T_mult:每次重启后增加迭代次数的乘法因子。默认为 1
  • eta_min:最小学习率。默认为 0
  • last_epoch:最新一轮。默认为 -1

其实现公式如下:


这个最主要的参数有两个:

  • T_0: 学习率第一次回到初始值的 epoch 位置
  • T_mult: 这个控制了学习率变化的速度, 重启之后因子,默认是 1
    • 每次 restart 后,T_0 = T_0 * T_mult
    • 如果 T_mult=1, 则学习率在 T_0, 2T_0, 3T_0, ...., i*T_0, .... 处回到最大值(初始学习率)
      • 5, 10, 15, 20, 25, ....... 处回到最大值
    • 如果 T_mult>1, 则学习率在 T_0, (1+T_mult)T_0,(1+T_mult+T_mult**2)T_0, ....., (1+T_mult+T_mult2+...+T_0i)*T0, 处回到最大值
      • 5, 15, 35, 75, 155, ....... 处回到最大值





所以可以看到,在调节参数的时候,一定要根据自己总的 epoch 合理的设置参数,不然很可能达不到预期的效果, 经过我自己的试验发现,如果是用那种等间隔的退火策略 (CosineAnnealingLR 和Tmult=1 的 CosineAnnealingWarmRestarts),验证准确率总是会在学习率的最低点达到一个很好的效果,而随着学习率回升,验证精度会有所下降. 所以为了能最终得到一个更好的收敛点,设置 T_mult>1 是很有必要的,这样到了训练后期,学习率不会再有一个回升的过程, 而且一直下降直到训练结束。



使用示例


下面是使用示例和画图的代码:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR
import torch.nn as nn
from torchvision.models import resnet18
import matplotlib.pyplot as plt

model = resnet18(pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mode = 'cosineAnnWarm'
if mode == 'cosineAnn':
    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0)
elif mode == 'cosineAnnWarm':
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1)
    '''
    以 T_0=5, T_mult=1 为例:
    T_0:学习率第一次回到初始值的 epoch 位置.
    T_mult: 这个控制了学习率回升的速度
        - 如果 T_mult=1, 则学习率在 T_0, 2*T_0, 3*T_0, ...., i*T_0, .... 处回到最大值(初始学习率)
            - 5,10,15,20,25,.......处回到最大值
        - 如果 T_mult>1, 则学习率在 T_0, (1+T_mult)*T_0, (1+T_mult+T_mult**2)*T_0, ....., (1+T_mult+T_mult**2+...+T_0**i)*T0, 处回到最大值
            - 5,15,35,75,155,.......处回到最大值
    example:
        T_0=5, T_mult=1
    '''
plt.figure()
max_epoch = 50
iters = 200
cur_lr_list = []
for epoch in range(max_epoch):
    for batch in range(iters):
        '''
        这里 scheduler.step(epoch + batch / iters) 的理解如下, 如果是一个 epoch 结束后再 .step
        那么一个 epoch 内所有 batch 使用的都是同一个学习率, 为了使得不同 batch 也使用不同的学习率
        则可以在这里进行 .step
        '''
        #scheduler.step(epoch + batch / iters)
        optimizer.step()
    scheduler.step()
    cur_lr = optimizer.param_groups[-1]['lr']
    cur_lr_list.append(cur_lr)
    print('cur_lr:',cur_lr)
x_list = list(range(len(cur_lr_list)))
plt.plot(x_list, cur_lr_list)
plt.show()
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR
import torch.nn as nn
from torchvision.models import resnet18
import matplotlib.pyplot as plt

model = resnet18(pretrained=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
mode = 'cosineAnnWarm'
if mode == 'cosineAnn':
    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0)
elif mode=='cosineAnnWarm':
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2)
plt.figure()
max_epoch = 20
iters = 5
cur_lr_list = []
for epoch in range(max_epoch):
    print('epoch_{}'.format(epoch))
    for batch in range(iters):
        scheduler.step(epoch + batch / iters)
        optimizer.step()
        # scheduler.step()
        cur_lr = optimizer.param_groups[-1]['lr']
        cur_lr_list.append(cur_lr)
        print('cur_lr:', cur_lr)
    print('epoch_{}_end'.format(epoch))
x_list = list(range(len(cur_lr_list)))
plt.plot(x_list, cur_lr_list)
plt.show()

最后, 对 scheduler.step(epoch + batch / iters) 的一个说明,这里的个人理解: 一个 epoch 结束后再 .step, 那么一个 epoch 内所有 batch 使用的都是同一个学习率, 为了使得不同 batch 也使用不同的学习率 , 则可以在这里进行 .step (将离散连续化, 或者说使得采样得更加的密集),下图是以 20 个 epoch,每个 epoch 5个 batch,T0=2, Tmul=2 画的学习率变化图







reference

https://zhuanlan.zhihu.com/p/261134624