PyTorch 中内存泄漏的典型现象就是数据并不大,但 GPU 的内存已经被占满,而且 GPU 的利用率(utilization)很低。为了更好的处理这个问题,我们总结一下 PyTorch 使用过程中的关于减少内存泄漏的实践经验。
文章目录
实验场景
我们的实验场景如下:
- PyTorch
- 用 DataLoader 读取训练集数据
- 使用 TorchVision 做数据预处理
- 使用的 GPU:Google’s Colab Pro with Tesla P100-PCIE-16GB GPU
模型输入是 128×128 的图像,训练集大概有 122k 张图片,校验集大概有 22k 张图片。
经验1:对 Loss 的处理
通常,在训练过程中,我们都是将 loss 的添加到一个 list 里保存。记住在保存前,先 detach,然后仅使用其数值。否则,你添加的就不仅仅是 loss,而是整个计算图。
正确用法:
1 2 3 | loss = F.mse_loss(prd, true) epoch_loss += loss.detach().item() training_log.append(epoch_loss) |
错误用法:
1 2 3 | loss = F.mse_loss(prd, true) epoch_loss += loss training_log.append(epoch_loss) |
经验2:将模型、输入、输出加载到 CUDA
避免机器内存的暴涨,记得把模型和从 dataloader 读取的输入数据放到 CUDA 里再使用。
正确用法:
1 2 3 4 5 6 | model = MyModel() model = model.to(device) for batch_idx, (x,y) in enumerate(train_loader): x = x.to(device) y = y.to(device) prd = model(x) |
错误用法:
1 2 3 | model = MyModel() for batch_idx, (x,y) in enumerate(train_loader): prd = model(x) |
经验3:使用垃圾回收
Python 在内存垃圾回收方面做的可能不太好,不用的变量往往不会被立即回收。要做到立即回收,最好在每个训练循环里加入下面的代码:
1 2 | import gc gc.collect() |
这个带来的效果可能微乎其微,但可以保证高效的垃圾回收。
经验4:DataLoader 的 worker 数量不是越多越好
如果你使用了多个 worker 读取数据,记住这个数并不是越多越好。很多的 worker 可能会因为进程协作的问题或者 IO 的问题而拖慢速度。
正确用法:
1 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_worker = [一个合理的数字]) |
错误用法:
1 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_worker = [一个特别大的数,比如 50,100 等]) |
关于多少 worker 才是最合适的,可以参考官方论坛的讨论帖:地址。
经验5:把数据保存在 Numpy Array 里,而不是 List 里
这个问题的官方讨论在这里:链接,关于解决方案引用如下:
Python lists store only references to the objects. The objects are kept separately in memory. Every object has a refcount, therefore every item in the list has a refcount.
Numpy arrays (of standard np types) are stored as continuous blocks in memory and are only ONE object with one refcount.
This changes if you make the NumPy array explicitly of type object, which makes it start behaving like a regular Python list (only storing references to (string) objects). The same “problems” with memory consumption now appear.”
所以,如果的在 DataLoader 中的数据是保存在 list 里的,记得用 np.array(x)
转换成 Numpy Array。
如果你还有其它的经验,欢迎留言分享:)
本站微信群、QQ群(三群号 726282629):