深度学习模型的超参数搜索和微调一直以来是最让我们头疼的一件事,也是最繁琐耗时的一个过程。现在好在已经有一些工具可以帮助我们进行自动化搜索,比如今天要介绍的 Tune。
现在通常用的比较多的超参数搜索算法有 Population Based Training (PBT), HyperBand, 和 ASHA 等。其中 PBT 由 DeepMind 提出,并在很多模型上取得了进展(如下图),详情参考这个链接。
但现在的问题是,很多研究者和团队并没有在自己的项目中使用这些搜索算法,很多超参数搜索框架也没有继承这些最新的算法。有时候模型规模一上去,一些框架就会变得很繁琐。
今天我们介绍一个很好很强大的超参数搜索框架 Tune:
功能对比
下表展示了 Tune 和现在其他几款超参数搜索框架的对比:
MNIST 示例
下面是一个入门示例,只要添加不超过 10 行代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | import torch.optim as optim from ray import tune from ray.tune.examples.mnist_pytorch import get_data_loaders, ConvNet, train, test def train_mnist(config): train_loader, test_loader = get_data_loaders() model = ConvNet() optimizer = optim.SGD(model.parameters(), lr=config["lr"]) for i in range(30): train(model, optimizer, train_loader) acc = test(model, test_loader) tune.track.log(mean_accuracy=acc) # 添加的代码 # 添加如下代码 analysis = tune.run( train_mnist, num_samples=10, # Uncomment this to let each evaluation use 1 GPU # resources_per_trial={"CPU": 1, "GPU": 1}, config={"lr": tune.grid_search([0.001, 0.01, 0.1])}) print("Best config: ", analysis.get_best_config(metric="mean_accuracy")) # 获取结果的 dataframe df = analysis.dataframe() |
运行结果:
此外,Tune 还与其它很多工具可以很好的集成,比如 MLFlow 和 Tensorboard。
如果你有自己实现的优化器,也可以继承 Tune 提供的优化器接口。Tune 还可以与 HyperOpt(如下面的代码示例) 或 Ax 很好的集成。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | from hyperopt import hp from ray.tune.suggest.hyperopt import HyperOptSearch space = { "lr": hp.loguniform("lr", 1e-10, 0.1), "momentum": hp.uniform("momentum", 0.1, 0.9), } hyperopt_search = HyperOptSearch( space, max_concurrent=2, reward_attr="mean_accuracy") analysis = tune.run( train_mnist, num_samples=10, search_alg=hyperopt_search) |
PyTorch 示例
下面我们用一个具体示例来说明如何用 PyTorch 和 Tune 实现 early stopping (ASHA)。
首先要安装 Ray,因为 Tune 是 Ray 的一个组件:
1 | pip install ray torch torchvision |
导入一些包:
1 2 3 4 5 6 7 8 9 10 11 12 | import numpy as np import torch import torch.optim as optim import torch.nn as nn import torch.nn.functional as F from ray import tune from ray.tune.examples.mnist_pytorch import get_data_loaders, train, test EPOCH_SIZE = 512 TEST_SIZE = 256 |
我们先用 PyTorch 实现一个简单的 CNN:
1 2 3 4 5 6 7 8 9 10 11 | class ConvNet(nn.Module): def __init__(self): super(ConvNet, self).__init__() self.conv1 = nn.Conv2d(1, 3, kernel_size=3) self.fc = nn.Linear(192, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 3)) x = x.view(-1, 192) x = self.fc(x) return F.log_softmax(x, dim=1) |
然后用 PyTorch 实现训练循环,并记录 log 到 Tune:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | def train_mnist(config): model = ConvNet() train_loader, test_loader = get_data_loaders() optimizer = optim.SGD( model.parameters(), lr=config["lr"], momentum=config["momentum"] ) for i in range(10): train(model, optimizer, train_loader) acc = test(model, test_loader) tune.track.log(mean_accuracy=acc) # 注意这一行 if i % 5 == 0: torch.save(model, "./model.pth") |
其中 get_data_loader()
函数如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | def get_data_loaders(): mnist_transforms = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))]) with FileLock(os.path.expanduser("~/data.lock")): train_loader = torch.utils.data.DataLoader( datasets.MNIST( "~/data", train=True, download=True, transform=mnist_transforms), batch_size=64, shuffle=True) test_loader = torch.utils.data.DataLoader( datasets.MNIST("~/data", train=False, transform=mnist_transforms), batch_size=64, shuffle=True) return train_loader, test_loader |
其中 train()
和 test()
函数定义如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | def train(model, optimizer, train_loader, device=torch.device("cpu")): model.train() for batch_idx, (data, target) in enumerate(train_loader): if batch_idx * len(data) > EPOCH_SIZE: return data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() def test(model, data_loader, device=torch.device("cpu")): model.eval() correct = 0 total = 0 with torch.no_grad(): for batch_idx, (data, target) in enumerate(data_loader): if batch_idx * len(data) > TEST_SIZE: break data, target = data.to(device), target.to(device) outputs = model(data) _, predicted = torch.max(outputs.data, 1) total = target.size(0) correct = (predicted == target).sum().item() return correct / total |
运行 Tune
针对上面的例子,我们先来运行一次搜索:
1 2 3 4 5 | search_space = { "lr": tune.choice([0.001, 0.01, 0.1]), "momentum": tune.uniform(0.1, 0.9) } analysis = tune.run(train_mnist, config=search_space) |
打印展示一下运行结果:
1 2 3 4 5 6 | dfs = analysis.trial_dataframes # Plot by epoch ax = None # This plots everything on the same plot for d in dfs.values(): ax = d.mean_accuracy.plot(ax=ax, legend=False) |
集成 early stopping (ASHA)
我们接下来集成 ASHA,一个针对 early stopping 的算法(论文地址)。ASHA 可以把 Tune 运行过程中相对较差的搜索结束掉,分配更多的时间和资源给其它搜索。
可以用 num_sample
讲所有搜索并行到多个可用的核上:
1 2 3 4 5 6 7 | from ray.tune.schedulers import ASHAScheduler analysis = tune.run( train_mnist, num_samples=30, scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"), config=search_space) |
运行结束后,可以用同样的方法可视化 dataframe:
此外,Tune 还支持大规模分布式超参数搜索,这里就不做过多介绍,感兴趣的读者可以去查看官方文档。
相关资源
本站微信群、QQ群(三群号 726282629):
不支持windos