您的位置 首页 PyTorch 教程

PyTorch 大批量数据在单个或多个 GPU 训练指南

PyTorch入门实战教程

在深度学习训练中,我们经常遇到 GPU 的内存太小的问题,如果我们的数据量比较大,别说大批量(large batch size)训练了,有时候甚至连一个训练样本都放不下。但是随机梯度下降(SGD)中,如果能使用更大的 Batch Size 训练,一般能得到更好的结果。所以问题来了:

问题来了:当 GPU 的内存不够时,如何使用大批量(large batch size)样本来训练神经网络呢?

这篇文章将以 PyTorch 为例,讲解一下几点:

  1. 当 GPU 的内存小于 Batch Size 的训练样本,或者甚至连一个样本都塞不下的时候,怎么用单个或多个 GPU 进行训练?
  2. 怎么尽量高效地利用多 GPU?

单个或多个 GPU 进行大批量训练

如果你也遇到过 CUDA RuntimeError: out of memory 的错误,那么说明你也遇到了这个问题。

PyTorch 的开发人员都出来了,估计一脸黑线:兄弟,这不是 bug,是你内存不够…

又一个方法可以解决这个问题:梯度累加(accumulating gradients)。

一般在 PyTorch 中,我们是这样来更新梯度的:

在上看的代码注释中,在计算梯度的 loss.backward() 操作中,每个参数的梯度被计算出来后,都被存储在各个参数对应的一个张量里:parameter.grad。然后优化器就会根据这个来更新每个参数的值,就是 optimizer.step()

而梯度累加(accumulating gradients)的基本思想就是, 在优化器更新参数前,也就是执行 optimizer.step() 前,我们进行多次梯度计算,保存在 parameter.grad 中,然后累加梯度再更新。这个在 PyTorch 中特别容易实现,因为 PyTorch 中,梯度值本身会保留,除非我们调用 model.zero_grad() or optimizer.zero_grad()

下面是一个梯度累加的例子,其中 accumulation_steps 就是要累加梯度的循环数:

如果连一个样本都不放下怎么办?

如果样本特别大,别说 batch training,要是 GPU 的内存连一个样本都不下怎么办呢?

答案是使用梯度检查点(gradient-checkpoingting),用计算量来换内存。基本思想就是,在反向传播的过程中,把梯度切分成几部分,分别对网络上的部分参数进行更新(见下图)。但这种方法的速度很慢,因为要增加额外的计算量。但在某些例子上又很有用,比如训练长序列的 RNN 模型等(感兴趣的话可以参考这篇文章)。

图片来自:https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9

这里就不展开讲了,可以参考 PyTorch 官方文档对 Checkpoint 的描述:https://pytorch.org/docs/stable/checkpoint.html

多 GPU 训练方法

简单来讲,PyTorch 中多 GPU 训练的方法是使用 torch.nn.DataParallel。非常简单,只需要一行代码:

在使用torch.nn.DataParallel 的过程中,我们经常遇到一个问题:第一个GPU的计算量往往比较大。我们先来看一下多 GPU 的训练过程原理:

在上图第一行第四个步骤中,GPU-1 其实汇集了所有 GPU 的运算结果。这个对于多分类问题还好,但如果是自然语言处理模型就会出现问题,导致 GPU-1 汇集的梯度过大,直接爆掉。

那么就要想办法实现多 GPU 的负载均衡,方法就是让 GPU-1 不汇集梯度,而是保存在各个 GPU 上。这个方法的关键就是要分布化我们的损失函数,让梯度在各个 GPU 上单独计算和反向传播。这里又一个开源的实现:https://github.com/zhanghang1989/PyTorch-Encoding。这里是一个修改版,可以直接在我们的代码里调用:地址。实例:

如果你的网络输出是多个,可以这样分解:

如果有时候不想进行分布式损失函数计算,可以这样手动汇集所有结果:

下图展示了负载均衡以后的原理:

此文由 PyTorch 中文网 整理自 Medium

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(4)

    1. 一般来讲,如果问题本身增加batch可以提升效果的话,梯度累加也可以。只有当增加batch后GPU内存不够的情况下,梯度累加才可以发挥优势。如果理解不对的话请指正:)

返回顶部