PyTorch C++ API 系列 1:用 VGG-16 识别 MNIST

PyTorch入门实战教程

自从 PyTorch C 接口发布以来,很少有教程专门针对这方面讲解。我们 PyTorch 中文网今天开始整理一套 PyTorch C API 系列教程,供大家参考。内容均源于网络。

完整的 PyTorch C++ 系列教程目录如下(或者点击这里查看):

  1. PyTorch C++ API 系列 1:用 VGG-16 识别 MNIST
  2. PyTorch C++ API 系列 2:使用自定义数据集
  3. PyTorch C++ API 系列 3:训练网络
  4. PyTorch C++ API 系列 4:实现猫狗分类器(一)
  5. PyTorch C++ API 系列 5:实现猫狗分类器(二)

本文讲解如何用 PyTorch C 实现 VGG-16 来识别 MNIST 数据集。

安装

首先下载 libtorch

CPU 版本:

GPU (CUDA 9.0) 版本:

GPU (CUDA 10.0) 版本:

然后将下载的压缩包解压缩。后面我们将使用解压后后的文件夹的绝对路径。

实现

VGG-16 的网络结构如下:

png

首先引入头文件:

然后实现网络定义:

训练

接下来我们测试训练网络,我们训练 10 个 epoch,学习率 0.01,使用 nll_loss 损失函数:

完整代码请参考:https://github.com/krshrimali/Digit-Recognition-MNIST-SVHN-PyTorch-CPP

参考资料

PyTorch入门实战教程

发表评论

电子邮件地址不会被公开。 必填项已用*标注

评论列表(1)

返回顶部