您的位置 首页 PyTorch 教程

BatchNorm 到底应该怎么用?

PyTorch入门实战教程

BatchNorm 最初是在 2015 年这篇论文中提出的,论文指出,BatchNorm 主要具有以下特性:

  • 更快的训练速度:由于 BatchNorm 的权重分布差异很小(论文中称为 internal covariate shift),我们可以使用更高的学习率来训练网络,让我们朝向损失函数最小的方向前进。
  • 改进网络正则化(Regularization):通过 BatchNorm 可以使网络在训练的时候,每个 batch 里的数据规范化都是不一样的,有助于减少网络过拟合。
  • 提高准确率:由于上述两点,整体上可以提高准确率。

BatchNorm 的原理

BatchNorm 的基本原理可以用一句话总结:保证网络每次接受的输入都是均值 0 标准差 1(mean 0 and a standard deviation of 1),算法原理如图所示:

下面是一个用纯 PyTorch 实现的 BatchNorm 算法过程:

需要注意的是,BatchNorm 在训练阶段和测试阶段的行为是不一样的,要通过切换 model 的模式来合理应用。

什么时候应该用 BatchNorm?

很多实验证明,BatchNorm 只要用了就有效果,所以一般情况下没有理由不用。

用的地方通常在一个全连接或者卷积层与激活函数中间,即 (全连接/卷积)—- BatchNorm —- 激活函数。但也有人说把 BatchNorm 放在激活函数后面效果更好,可以都试一下。

什么时候不应该用 BatchNorm?

当每个 batch 里所有的 sample 都非常相似的时候,相似到 mean 和 variance 都基本为 0,则最好不要用 BatchNorm。

当然,还有一种情况,如果 batch size 为 1,从原理上来讲,用 BatchNorm 是没有任何意义的。

特别注意:Transfer Learning

通常我们在进行 Transfer Learning 的时候,会冻结之前的网络权重,注意这时候往往也会冻结 BatchNorm 中训练好的 moving averages 值。这写 moving averages 值只适用于以前的旧的数据,对新数据不一定适用。所以最好的方法是在 Transfer Learning 的时候不要冻结 BatchNorm 层,让 moving averages 重新从新的数据中学习。

PyTorch入门实战教程

发表评论

电子邮件地址不会被公开。 必填项已用*标注

返回顶部