PyTorch C++ API 系列 3:训练网络

PyTorch入门实战教程

上一篇教程中,我们展示了如何用 PyTorch C API 自定义数据集的读取并定义了一个简单的 VGG-16 网络。这一节我们看看如何训练网络。

完整的 PyTorch C++ 系列教程目录如下(或者点击这里查看):

  1. PyTorch C++ API 系列 1:用 VGG-16 识别 MNIST
  2. PyTorch C++ API 系列 2:使用自定义数据集
  3. PyTorch C++ API 系列 3:训练网络
  4. PyTorch C++ API 系列 4:实现猫狗分类器(一)
  5. PyTorch C++ API 系列 5:实现猫狗分类器(二)

我们来简单做个回顾,用下面的代码模板可是实现 PyTorch C 自定义数据读取的一般步骤:

下图展示了这个过程的一般步骤:

回顾结束,我们来看看怎么用 PyTorch C++ API 训练网络吧。

怎样传批数据(Batch)?

在传统的 Python 中,我们需要先定义一个 data loader,如下:

同样的,在 C++ API 里,我们这样做:

这里面的 SequentialSampler 类负责按照我们提供的数据顺序来生成样本,关于这个类的文档可以参考这里

关于函数 torch::data::make_data_loader 可以在这里找到相关文档。

在训练循环中,可以这样读取 data 和 target:

这里每次取得的数据大小取决于之前 torch::data::make_data_loader() 函数中传入的 batch_size 大小。

定义超参数

这里说到超参数,我们通常是指:

  • Batch Size
  • Optimizer 优化器
  • Loss Function 损失函数

Batch Size

对于 Batch size 我们上一节讲过如何定义了。

优化器

这里我们看一下如何定义优化器:

目前 PyTorch C++ API 仅支持以下优化器:

  1. RMSprop
  2. SGD
  3. Adam
  4. Adagrad
  5. LBFGS
  6. LossClosureOptimizer

损失函数

关于损失函数,我们这里以 nll_loss 为例:

如果要输出 loss 的具体值,可以使用 loss.item<float>()

开始训练

训练循环部分基本上逻辑和 Python 中是一样的,见代码:

注意最后采用 torch::save() 来保存训练好的模型。

下一节我们将用 PyTorch C++ API 实现一个完整的猫狗分类器,敬请期待。

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表回复

您的电子邮箱地址不会被公开。

评论列表(2)

返回顶部