您的位置 首页 PyTorch 教程

用 PyTorch 实现一个鲜花分类器

PyTorch入门实战教程

我们今天来训练一个模型识别 102 种花的种类,给定一个花的照片,可以识别出花名。

数据集

这次用到的鲜花数据集来自 http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html,需要的朋友请自行下载。

数据集包含 7370 幅花的图片,下面是一些示例:

我们把数据集分为下列三部分:

数据增广

这里对 Training 的数据,采用了四种增广方式,比例分别为:

然后,随机裁剪(random crop)后缩放到 224×224 的大小。

对于 Testing 和 Validation,只进行缩放,到 224×224 的大小。

创建分类器

这里的网络我们采用一个预训练网络 ResNet-152 模型,但是我们只保留卷积层的权重,最后的分类器要替换成一个输出 102 类的全连接网络,这里的全连接网络我们有三层,输入层 2048,隐藏层 1000,输出层 102。网络示意图如下:

训练网络

这里经过多次尝试后,我们决定最后选取 Adagrad 优化器,来保证训练效果最好。

整个训练过程分为三个阶段。

阶段一

这一阶段集中训练分类器,就是最后的全连接层。

因为前面的卷积层都是预训练的,可以很好地提取图像特征,因此我们先冻结卷积层的权重,只对最后一部分全连接层进行训练。

这一阶段训练了 39 个 epoch,最后在 validation 得到 94.62% 的准确率:

阶段二

现在我们可以说全连接层的分类器已经训练好了,那接下来我们微调卷积层的权重。

这时候,取消冻结前面卷积层的权重,对整个网络全部进行训练,20 个 epoch 后,validation 得到 96.58% 的准确率:

阶段三

虽然这时候网络性能已经非常好了,但我们还想继续提高一下。

这次我们将 learning rate 设置为 0.000001,只是为了对网络的权重做极其微小的调整。在训练了 10 个 epoch 之后,我们在 validation 上得到了 97.07% 的准确率:

充分证明我们的三个阶段的训练是有效的。

下面是训练曲线和学习率曲线:

结果

我们最后在 Testing 上做测试,我们的网络竟然得到了 99.27% 的准确率,已经非常不错了!

下图是一些测试结果:

代码地址

如果对代码感兴趣,可以到 Github 上获取,地址 https://github.com/silviomori/udacity-deeplearning-pytorch-challenge-lab

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

返回顶部