您的位置 首页 PyTorch 教程

用 Tune 快速进行超参数优化(Hyperparameter Tuning)

PyTorch入门实战教程

深度学习模型的超参数搜索和微调一直以来是最让我们头疼的一件事,也是最繁琐耗时的一个过程。现在好在已经有一些工具可以帮助我们进行自动化搜索,比如今天要介绍的 Tune

现在通常用的比较多的超参数搜索算法有 Population Based Training (PBT), HyperBand, 和 ASHA 等。其中 PBT 由 DeepMind 提出,并在很多模型上取得了进展(如下图),详情参考这个链接

但现在的问题是,很多研究者和团队并没有在自己的项目中使用这些搜索算法,很多超参数搜索框架也没有继承这些最新的算法。有时候模型规模一上去,一些框架就会变得很繁琐。

今天我们介绍一个很好很强大的超参数搜索框架 Tune

功能对比

下表展示了 Tune 和现在其他几款超参数搜索框架的对比:

MNIST 示例

下面是一个入门示例,只要添加不超过 10 行代码:

运行结果:

此外,Tune 还与其它很多工具可以很好的集成,比如 MLFlow 和 Tensorboard。

如果你有自己实现的优化器,也可以继承 Tune 提供的优化器接口。Tune 还可以与 HyperOpt(如下面的代码示例) 或 Ax 很好的集成。

PyTorch 示例

下面我们用一个具体示例来说明如何用 PyTorch 和 Tune 实现 early stopping (ASHA)。

首先要安装 Ray,因为 Tune 是 Ray 的一个组件:

导入一些包:

我们先用 PyTorch 实现一个简单的 CNN:

然后用 PyTorch 实现训练循环,并记录 log 到 Tune:

其中 get_data_loader() 函数如下:

其中 train()test() 函数定义如下:

运行 Tune

针对上面的例子,我们先来运行一次搜索:

打印展示一下运行结果:

集成 early stopping (ASHA)

我们接下来集成 ASHA,一个针对 early stopping 的算法(论文地址)。ASHA 可以把 Tune 运行过程中相对较差的搜索结束掉,分配更多的时间和资源给其它搜索。

可以用 num_sample 讲所有搜索并行到多个可用的核上:

运行结束后,可以用同样的方法可视化 dataframe:

此外,Tune 还支持大规模分布式超参数搜索,这里就不做过多介绍,感兴趣的读者可以去查看官方文档。

相关资源

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(1)

返回顶部