PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

我们描述了最近提出的随机加权平均(SWA)技术,以及它在 torchcontrib中的新实现。

导读

前不久PyTorch发了一篇官方博客,就是这篇SWA的文章,在torchcontrib中实现了SWA,从此以后,SWA也可以直接用了,可以在不增加推理时间的情况下,提高泛化能力,而且用起来非常简单,还不来试试!

在这篇博文中,我们描述了最近提出的随机加权平均(SWA)技术,以及它在 torchcontrib中的新实现。SWA是一个简单的过程,它可以在不增加任何额外成本的情况下改进深度学习中对随机梯度下降(SGD)的泛化,并且可以作为PyTorch中任何其他优化器的替代。SWA有广泛的应用和功能:

  1. SWA已被证明可以显著提高计算机视觉任务的泛化能力,包括VGG、ResNets、Wide ResNets和DenseNets对ImageNet和CIFAR基准测试的泛化能力。
  2. SWA在半监督学习和领域适应的关键基准上提供了最先进的性能。
  3. 研究表明,SWA能够提高训练的稳定性,提高深度强化学习中策略梯度方法的最终平均奖励。
  4. 在深度学习中,对SWA的扩展可以得到高效的贝叶斯模型平均,以及高质量的对不确定性的估计和标定。
  5. 对于低精度训练,SWALP可以媲美全精度SGD的性能,即使所有数字量化到8位,包括使用梯度累加器。

简而言之,SWA使用一个改变的学习率策略对SGD路径上的权值进行平均(参见图1的左图)。SWA的解最终停在一个宽的平坦损失区域的中心,而SGD趋向于收敛到低损失区域的边界,使得它容易受到训练和测试误差面之间的偏移的影响(见图1的中、右图)。

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

在CIFAR-100上用预激活的ResNet-164,采用SWA和SGD

左:三个FGE样本及其对应的SWA的解的测试误差曲面(权值空间平均)。中和右:测试误差和训练损失面,显示SGD(收敛时)和SWA提出的权值,使用和SGD相同的初始化,125个训练阶段后。

使用我们在torchcontrib中的实现来使用SWA使用PyTorch中的任何其他优化器一样简单:

from torchcontrib.optim import SWA
...
...
# training loop
base_opt = torch.optim.SGD(model.parameters, lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):
opt.zero_grad
loss_fn(model(input), target).backward
opt.step
opt.swap_swa_sgd

你可以使用 torch.optim中的SWA类来包装所有的优化器,然后像往常一样训练你的模型。当训练完成时,只需调用swap_swa_sgd来将模型的权重设置为它们的SWA平均值。下面我们将详细解释SWA过程和SWA类的参数。我们强调,SWA可以与任何优化过程相结合,比如Adam,就像它可以与SGD相结合一样。

这个就只是Averaged SGD吗?

在较高的层次上,Averaged SGD迭代可以追溯到几十年前的凸优化,在凸优化中有时被称为Polyak-Ruppert平均,或averaged SGD。但是细节很重要。Averaged SGD通常与衰减学习率和指数移动平均一起使用,通常用于凸优化。在凸优化中,提高收敛速度一直是研究的重点。在深度学习中,这种形式的averaged SGD平滑了SGD迭代的轨迹,但执行起来并没有太大的不同。

相比之下,SWA关注的是一个equal average的SGD迭代,使用一个变化的循环或高恒定学习率,并利用针对深度学习的训练目标的平坦性来改进泛化能力。

随机权值平均

SWA的工作有两个重要的组成部分。首先,SWA使用一个变化的学习率策略,这样SGD就可以继续探索一组高性能的网络,而不是简单地收敛到一个单一的解决方案。例如,我们可以对前75%的训练时间使用标准衰减学习率策略,然后将剩余25%的时间的学习率设置为一个相当高的恒定值(参见下面的图2)。第二步是对SGD遍历的网络的权值进行平均。例如,我们可以在最后25%的训练时间内保持每个epoch结束时获得的权重的运行状态平均值(参见图2)。

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

图解SWA采用的学习率策略

前75%的训练采用标准衰减策略,其余25%采用大的恒定值。SWA平均值是在最后25%的训练中形成的。

在我们的实现中, SWA优化器的自动模式允许我们运行上面描述的过程。要在自动模式下运行SWA,只需将优化器的base_opt(可以是SGD、Adam或任何其他torch.optim.Optimizer)与SWA(base_opt,swa_start,swa_freq,swa_lr)封装在一起。在swa_freq优化步骤之后,学习率将切换到一个常数值swa_lr,在每个swa_freq优化步骤的末尾,权重快照将添加到SWA运行平均值中。一旦运行opt.swap_swa_sgd,模型的权重就会被它们的SWA运行平均值所替代。

BATCH NORMALIZATION

要记住的一个重要细节是batch normalization。batch normalization层计算训练期间活动的运行统计数据。注意,权值的SWA平均值从未用于在训练期间进行预测,因此在使用 opt.swap_swa_sgd重置模型的权值后,batch normalization层没有计算激活统计信息。要计算激活统计数据,只需在训练结束后使用SWA模型对训练数据进行正向传递。在SWA类中,我们提供了一个辅助函数opt.bn_update(train_loader,model)。它通过向前传递train_loader数据加载器来更新模型中每个批处理规范化层的激活统计信息。你只需要在训练结束时调用这个函数一次。

高级学习率策略

SWA可以与任何有助于探索解决方案落在平坦区域的学习率策略一起使用。例如,你可以在最后25%的训练时间内使用循环学习率,而不是一个固定值,并对每个周期内学习率最低值对应的网络的权重进行平均(参见图3)。

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

图解SWA与另一种学习率策略一起使用

在最后25%的训练中采用循环学习率,并在每个周期结束时对模型参数进行平均。

在我们的实现中,你可以通过在手动模式下使用 SWA实现自定义学习率和权重平均策略。下面的代码相当于这篇博客文章开头介绍的自动模式代码。

opt = torchcontrib.optim.SWA(base_opt)
for i in range(100):
opt.zero_grad
loss_fn(model(input), target).backward
opt.step
if i > 10 and i % 5 == 0:
opt.update_swa
opt.swap_swa_sgd

在手动模式下,你不需要指定 swa_start,swa_lr和swa_freq,只要在需要更新SWA运行平均值时调用opt.update_swa(例如在每个学习周期的末尾)。在手动模式下,SWA不会改变学习速度,所以你可以使用任何学习率策略,就像你通常使用任何其他torch.optim.Optimizer一样。

为什么这样可以工作?

SGD收敛于一个宽平坦区域内的解。权值空间具有极高的高维性,且大部分平坦区域的体积都集中在边界附近,所以SGD的解总是在损失平坦区域的边界附近找到。另一方面,SWA平均多个SGD的解,这使得它可以移动到平坦区域的中心。

我们期望在损失平坦区域中心的解比边界附近的解更具泛化性。实际上,训练和测试误差面在权值空间中并不是完全对齐的。以平面为中心的解不像边界附近的解那样容易受到训练误差面与测试误差面的影响。在下面的图4中,我们展示了连接SWA和SGD的解的方向上的训练损失和测试误差面。如你所见,虽然SWA方案比SGD方案具有更高的训练损失,但是它集中在低损失区域,并且具有更好的测试误差。

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

连接SWA解和SGD解的线路上的训练损失和测试误差

SWA解集中在较宽的训练低损失区域,SGD解位于边界附近。由于训练损失与测试误差曲面之间的位移,SWA解具有较好的泛化效果。

例子和结果

我们发布了一个GitHub repo: (https://github.com/izmailovpavel/contribswaexamples),其中包含了使用SWA的 torchcontrib实现来训练DNNs的例子。例如,这些例子可以在CIFAR-100上实现以下结果:

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

半监督学习

在后续的论文(https://arxiv.org/abs/1806.05594)中,SWA被应用到半监督学习中,在多个设置中,它展示了超出最佳报告结果的改进。例如,如果只有4k训练数据点的训练标签(之前关于这个问题的最佳报告结果是93.7%),那么使用SWA,你可以在CIFAR-10上获得95%的准确率。本文还研究了在给定时间内多次平均的方法,它可以加快收敛速度,并在给定时间内找到平坦

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

基于CIFAR-10的快速swa半监督学习性能

fast-SWA在每个设置中都获得了创纪录的结果。

校准和不确定性估计

SWA-Gaussian (SWAG)是贝叶斯深度学习中用于不确定性估计和校准的一种简单、可扩展和方便的方法。它也使用SGD迭代运行平均值与SWA类似,SWAG估计迭代的第一和第二阶矩,以构造权重上的高斯分布。SWAG分布近似真实后验的形状,下面的图6显示了PreResNet-164在CIFAR-100上的后验对数密度顶部的SWAG分布。

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

SWAG分布在CIFAR-100上PreResNet-164的后验对数密度的顶部

SWAG分布的形状与后验分布对齐。

从经验上看,SWAG在计算机视觉任务中的不确定性量化、异常检测、标定和迁移学习等方面的性能与MC dropout、KFAC拉普拉斯、温度缩放等常用方法相当或更好。SWAG的代码在此处(https://github.com/wjmaddox/swa_gaussian)。

强化学习

在另一篇后续的论文中,SWA被证明能够提高策略梯度方法A2C和DDPG在Atari游戏和MuJoCo环境中的性能。

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

低精度训练

我们可以通过组合向上舍入权值和向下舍入权值来过滤量化噪声。而且,通过平均权值找到的损失表面的平坦区域,权值的大扰动不会影响解的质量(图7和8)。最近的工作表明,SWA适应低精度设置,有个方法称为SWALP,可以使用8bit训练匹配全精度的SGD的性能。这是一个非常重要的实际结果,因为(1)8bit的SGD训练比全精度SGD训练表现明显要差,(2)低精度训练比训练后使用低精度预测(常用的设置)要困难得多。例如,使用浮点(16位)SGD在CIFAR-100上训练的ResNet-164可以达到22.2%的错误,而8位SGD可以达到24.0%的错误。相比之下,经过8位训练的SWALP的错误率为21.8%。

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

在平坦区域进行量化仍然可以得到低损失的解

PyTorch官方推荐!SWA:不增加推理时间提高泛化能力的集成方法

低精度SGD训练(使用变化的学习率策略),SWALP

总结

深度学习中最大的开放性问题之一是,既然训练目标是高度多模态的,而且原则上有许多参数设置不会造成训练上的变差,但泛化能力会变差,为什么SGD能够找到好的解。通过理解与泛化相关的几何特性(如平面度),我们可以开始解决这些问题,并构建提供更好泛化性能的优化器,以及许多其他有用的特性(如不确定性表示)。我们提出了一种简单的替代标准SGD的SWA,它原则上可以使任何训练深度神经网络的人受益。SWA在计算机视觉、半监督学习、强化学习、不确定性表示、标定、贝叶斯模型平均和低精度训练等方面都有较强的性能。

我们鼓励你尝试SWA!现在使用SWA就像使用PyTorch中的任何其他优化器一样简单。而且,即使你已经使用SGD(或任何其他优化器)对模型进行了训练,也很容易通过从一个预训练的模型开始运行SWA,在少量时间内运行SWA来实现SWA的好处。

作者:Pavel Izmailov and Andrew Gordon Wilson

编译:ronghuaiyang

英文原文:https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/

本文为专栏文章,来自:AI公园,内容观点不代表本站立场,如若转载请联系专栏作者,本文链接:https://www.afenxi.com/66262.html 。

(0)
AI公园的头像AI公园专栏
上一篇 2019-10-09 09:04
下一篇 2019-10-10 09:53

相关文章

关注我们
关注我们
分享本页
返回顶部