我们这节教程用 CIFAR10 图像分类作为例子,讲解 fastai 中 ConvLearner 的用法。如果对 fastai 还不熟悉,可以参考下面两篇教程,文末有本文代码的 jupyter notebook 供大家自己测试。
准备 CIFAR10 数据
在新版的 fastai 里,我们可以使用 untar_data 来自动下载一些预定义好的数据。首先当然是引入必要的包:
1 2 3 4 | from fastai import * from fastai.vision import * from fastai.vision.models.wrn import wrn_22 from fastai.docs import * |
下载数据:
1 | untar_data(CIFAR_PATH) |
如果打印 CIFAR_PATH 的值,发现是:
1 | PosixPath('../data/cifar10') |
到相应的文件夹,可以看到主要是两个文件夹:train 和 test,还有一个 labels.txt 包含了所有的 label 数据:
我们可以打印一张图看看:
1 2 | img = open_image('../data/cifar10/train/airplane/30_airplane.png') img |
初始化模型
我们先定义一组对图像的增广:
1 | ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], []) |
然后我们用新的 image_data_from_folder
函数来从数据文件夹读取数据,并做图像增广:
1 | data = image_data_from_folder(CIFAR_PATH, valid='test', ds_tfms=ds_tfms, tfms=cifar_norm, bs=512) |
关于image_data_from_folder
的用法,可以参考官方 API 文档。
接下来,我们定义 Learner 类(Learner 类用法):
1 | learn = Learner(data, wrn_22(), metrics=accuracy).to_fp16() |
这里的 wrn_22
指的是 WideResNet,我们这里载入预训练模型。这个文件详细定义了 wrn_22 的模型结构。
训练
我们选择用 Learner 类的 fit
方法来训练:
1 | learn.fit_one_cycle(30, 3e-3, wd=0.4, div_factor=10, pct_start=0.5) |
代码很简洁,只需要一行,这也是 fastai 库的魅力所在,可以快速实现很多原型的测试与复用。
Notebook 下载
本节的 Jupyter notebook 可到 Github 下载。
本站微信群、QQ群(三群号 726282629):