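As of version 1.2.0, PyTorch officially ships with built-in TensorBoard support, so visualization no longer depends on third-party tools.
This article covers the basics of using the TensorBoard integration bundled with PyTorch 1.2.0.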
Installation
PyTorch 1.2.0 or later is required:
    pip install --upgrade torch torchvision
Then install TensorBoard (version 1.14 or later):
    pip install tensorboard
Once installation finishes, the package should import successfully:
    Python 3.7.4 (default, Aug 13 2019, 20:35:49)
    [GCC 7.3.0] :: Anaconda, Inc. on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> from torch.utils.tensorboard import SummaryWriter
    /home/seungjaeryanlee/.conda/envs/torchtest/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint8 = np.dtype([("qint8", np.int8, 1)])
    /home/seungjaeryanlee/.conda/envs/torchtest/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
    /home/seungjaeryanlee/.conda/envs/torchtest/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint16 = np.dtype([("qint16", np.int16, 1)])
    /home/seungjaeryanlee/.conda/envs/torchtest/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
    /home/seungjaeryanlee/.conda/envs/torchtest/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      _np_qint32 = np.dtype([("qint32", np.int32, 1)])
    /home/seungjaeryanlee/.conda/envs/torchtest/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
      np_resource = np.dtype([("resource", np.ubyte, 1)])
    >>>
The warnings above are NumPy deprecation notices raised inside TensorBoard and can be ignored.
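Importing the Writer
To use TensorBoard, import the SummaryWriter class in your Python code and give it an output directory:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(PATH_to_log_dir)
If the output directory does not exist, it is created automatically.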
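As a minimal sketch of this behavior, the snippet below points a `SummaryWriter` at a directory that does not exist yet and then checks that it has appeared. The temporary location and the name `runs_demo` are assumptions made for illustration; any writable path should work the same way.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# Hypothetical log location: a not-yet-existing subdirectory of a temp dir.
base_dir = tempfile.mkdtemp()
log_dir = os.path.join(base_dir, "runs_demo")
existed_before = os.path.exists(log_dir)

writer = SummaryWriter(log_dir)  # creates log_dir if it is missing
writer.close()                   # flush pending events and release the file

created = os.path.isdir(log_dir)
print(existed_before, created)
```

Calling `close()` at the end is a cautious habit here: it makes sure everything buffered has been written before the directory is inspected.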
Running TensorBoard
Start TensorBoard from the command line:
    tensorboard --logdir=/path_to_log_dir/ --port 6006
Replace the path and port number in the command with your own.
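Usage
We use the Fashion-MNIST dataset to demonstrate the basics of TensorBoard. The dataset is publicly available for download.
The network is deliberately simple and serves only as a demonstration: two convolutional layers (with max pooling after the second), two fully connected layers, 10 output classes, and ReLU activations: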
    import torch.nn as nn
    import torch.nn.functional as F

    class MNISTClass(nn.Module):
        def __init__(self):
            super(MNISTClass, self).__init__()
            self.conv1 = nn.Conv2d(1, 15, kernel_size=3, stride=1)
            self.conv2 = nn.Conv2d(15, 30, kernel_size=3, stride=2)
            self.fc1 = nn.Linear(1080, 100)
            self.fc2 = nn.Linear(100, 10)

        def forward(self, x):
            # conv1(kernel=3, filters=15) 28x28x1 -> 26x26x15
            x = F.relu(self.conv1(x))

            # conv2(kernel=3, stride=2, filters=30) 26x26x15 -> 12x12x30
            # max_pool(kernel=2) 12x12x30 -> 6x6x30
            x = F.relu(F.max_pool2d(self.conv2(x), 2, stride=2))

            # flatten 6x6x30 = 1080
            x = x.view(-1, 1080)

            # 1080 -> 100
            x = F.relu(self.fc1(x))

            # 100 -> 10
            x = self.fc2(x)

            # convert the raw scores (logits) to log-probabilities
            return F.log_softmax(x, dim=1)
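The flattened size of 1080 follows from the usual output-size formula for unpadded convolution and pooling layers, floor((n - k) / s) + 1. The helper below, whose name `out_size` is made up for this sketch, traces a 28x28 input through the three layers of the network.

```python
def out_size(n: int, kernel: int, stride: int) -> int:
    """Spatial output size of an unpadded conv/pool layer: floor((n - k) / s) + 1."""
    return (n - kernel) // stride + 1


size = 28                                     # Fashion-MNIST images are 28x28
after_conv1 = out_size(size, 3, 1)            # conv1: kernel 3, stride 1
after_conv2 = out_size(after_conv1, 3, 2)     # conv2: kernel 3, stride 2
after_pool = out_size(after_conv2, 2, 2)      # max_pool2d: kernel 2, stride 2
flat_features = after_pool * after_pool * 30  # 30 channels come out of conv2

print(after_conv1, after_conv2, after_pool, flat_features)
# prints: 26 12 6 1080
```

Note that the stride-2 convolution takes 26 down to 12, and pooling halves that to 6, which is where the 6 x 6 x 30 = 1080 input features of `fc1` come from.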
Logging Scalars
Commonly tracked quantities such as loss and accuracy are plain numbers, and logging a number to TensorBoard is straightforward:
    add_scalar(tag, scalar_value, global_step=None, walltime=None)
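tag is the label the scalar belongs to (for example, training_loss). A common convention is to write the tag as section/plot, so that TensorBoard groups the plots by section (see the Fashion-MNIST example below).
global_step is an integer and usually serves as the x-axis of the plot; if it is not set, every point is recorded at step 0. Note that nothing is overwritten: a new value logged at an existing global_step does not replace the earlier one, and both are drawn on the plot.
walltime is the timestamp of the record and defaults to the current system time, time.time().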
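A small self-contained sketch of these parameters: it logs a synthetic, steadily decreasing loss under the tag `Train/Loss`, using the epoch index as `global_step`. The loss values and the temporary log directory are invented for illustration and do not come from a real training run.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()  # hypothetical log location
writer = SummaryWriter(log_dir)

# Synthetic loss values standing in for a real training loop.
fake_losses = [1.0 / (epoch + 1) for epoch in range(5)]

for epoch, loss in enumerate(fake_losses):
    # "Train/Loss" places the plot "Loss" in the "Train" section;
    # the epoch index becomes the x-axis value of the point.
    writer.add_scalar("Train/Loss", loss, global_step=epoch)

writer.close()

# The events file written by the writer lives inside log_dir.
event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
```

Pointing `tensorboard --logdir` at the printed value of `log_dir` should then show a single curve with five points.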
For the Fashion-MNIST dataset, we mainly record the training and test loss, along with the test accuracy:
    # Record training loss from each epoch into the writer
    writer.add_scalar('Train/Loss', loss.item(), epoch)
    writer.flush()

    # Record loss and accuracy from the test run into the writer
    writer.add_scalar('Test/Loss', test_loss, epoch)
    writer.add_scalar('Test/Accuracy', accuracy, epoch)
    writer.flush()
These values then show up on TensorBoard's Scalars tab, grouped under the Train and Test sections.
Logging Images
The signature for logging an image is:
    add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')
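tag has the same meaning as for scalars, while global_step acts more like a label attached to the image. The default dataformats='CHW' means img_tensor is laid out as channels, height, width.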
In our example, we first tile several images into a grid (utils here is torchvision.utils) and then write the grid to TensorBoard:
    # To inspect the input dataset, visualize it as a grid
    grid = utils.make_grid(images)
    writer.add_image('Dataset/Inspect input grid', grid, global_step=0)
    writer.close()
The grid then shows up on TensorBoard's Images tab.
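You can use global_step to log an image for every epoch of training, which makes it possible to compare how the images change over the course of the run.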
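One way this might look, as a hedged sketch rather than the original training code: a random tensor stands in for whatever image the model produces in each epoch, and the tag `Samples/random` and the temporary log directory are made up. It also assumes Pillow is installed, since PyTorch uses it to encode images for TensorBoard.

```python
import os
import tempfile

import torch
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()  # hypothetical log location
writer = SummaryWriter(log_dir)

num_epochs = 3
for epoch in range(num_epochs):
    # Stand-in for a per-epoch result: a random 1x28x28 image in [0, 1).
    # The layout is CHW, matching the default dataformats='CHW'.
    img = torch.rand(1, 28, 28)
    # Same tag every time; global_step distinguishes the epochs.
    writer.add_image("Samples/random", img, global_step=epoch)

writer.close()

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
```

Because all three images share one tag, TensorBoard presents them as a single entry whose step can be changed with a slider.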
Logging Histograms
The signature for logging a histogram is:
    add_histogram(tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None)
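Again, tag has the same meaning as before. global_step orders the histograms: those logged at successive steps are displayed as a sequence stacked one behind another. values holds the data to be binned and can be a torch.Tensor or a NumPy array.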
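To illustrate, the sketch below logs three histograms of synthetic "weights" whose spread widens at each step. The tag `Demo/weights`, the random data, and the temporary log directory are all assumptions for demonstration; in a real run, `values` would typically be a layer's parameters or gradients.

```python
import os
import tempfile

import numpy as np
from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()  # hypothetical log location
writer = SummaryWriter(log_dir)
rng = np.random.RandomState(0)  # fixed seed so the run is repeatable

num_steps = 3
for step in range(num_steps):
    # Synthetic values: normally distributed, wider at every step.
    values = rng.normal(loc=0.0, scale=1.0 + step, size=1000)
    # One histogram per step; TensorBoard stacks them by global_step.
    writer.add_histogram("Demo/weights", values, global_step=step)

writer.close()

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
```

Viewed in TensorBoard, the three histograms should appear progressively flatter and wider from the first step to the last.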
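Summary
PyTorch's built-in TensorBoard support still has shortcomings; for example, it often fails to render the network graph. Future releases will likely improve on this.
Beyond the built-in support, third-party tools can also be used for visualization.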