我们在上文中介绍了 fastai 的安装,本文将带领大家通过 MNIST 的例子快速上手 fastai。
引入包
首先我们引入必要的一些包:
1 2 3 | from fastai import * # 大多数通用的函数 from fastai.vision import * # 大多数与计算机视觉有关的函数 from fastai.docs import * # 这个包提供一些示例数据集 |
fastai 中最重要的是 Leaner 类,我们先来熟悉一下 Learner 中的参数:
1 | doc(Learner) |
结果如下:
1 | Learner(data:DataBunch, model:Module, opt_fn:Callable=\'Adam\', loss_fn:Callable=\'cross_entropy\', metrics:Collection[Callable]=None, true_wd:bool=True, bn_wd:bool=True, wd:Floats=0.01, train_bn:bool=True, path:str=None, model_dir:str=\'models\', callback_fns:Collection[Callable]=None, callbacks:Collection[Callback]=, layer_groups:ModuleList=None) |
感兴趣的读者可以在这里查看官方文档。
创建一个 DataBunch
图像数据可以是一个标记好的文件夹,也可以是一个单独的文件和一个 csv 文件。
1 2 | untar_data(MNIST_PATH) MNIST_PATH |
输出结果:
1 | PosixPath(\'../data/mnist_sample\') |
创建一个 DataBunch:
1 2 3 | data = image_data_from_folder(MNIST_PATH) img,label = data.train_ds[0] img |
输出结果:
开始训练
创建一个 Learner 类,调用 fit 函数进行拟合:
1 2 | learn = ConvLearner(data, tvm.resnet18, metrics=accuracy) learn.fit(1) |
输出结果:
1 2 3 4 | VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]'))), HTML(value… Total time: 00:06 epoch train loss valid loss accuracy 0 0.088343 0.045195 0.986261 (00:06) |
本站微信群、QQ群(三群号 726282629):
