在上一篇教程中,我们展示了如何用 PyTorch C API 自定义数据集的读取并定义了一个简单的 VGG-16 网络。这一节我们看看如何训练网络。
完整的 PyTorch C++ 系列教程目录如下(或者点击这里查看):
- 《PyTorch C++ API 系列 1:用 VGG-16 识别 MNIST》
- 《PyTorch C++ API 系列 2:使用自定义数据集》
- 《PyTorch C++ API 系列 3:训练网络》
- 《PyTorch C++ API 系列 4:实现猫狗分类器(一)》
- 《PyTorch C++ API 系列 5:实现猫狗分类器(二)》
我们来简单做个回顾,用下面的代码模板可是实现 PyTorch C 自定义数据读取的一般步骤:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | class CustomDataset : public torch::data::dataset { private: // Declare 2 vectors of tensors for images and labels vector images, labels; public: // Constructor CustomDataset(vector list_images, vector list_labels) { images = process_images(list_images); labels = process_labels(list_labels); }; // Override get() function to return tensor at location index torch::data::Example<> get(size_t index) override { torch::Tensor sample_img = images.at(index); torch::Tensor sample_label = labels.at(index); return {sample_img.clone(), sample_label.clone()}; }; // Return the length of data torch::optional size() const override { return labels.size(); }; }; int main(int argc, char** argv) { vector list_images; // list of path of images vector list_labels; // list of integer labels // Dataset init and apply transforms - None! auto custom_dataset = CustomDataset(list_images, list_labels).map(torch::data::transforms::Stack<>()); } |
下图展示了这个过程的一般步骤:
回顾结束,我们来看看怎么用 PyTorch C++ API 训练网络吧。
怎样传批数据(Batch)?
在传统的 Python 中,我们需要先定义一个 data loader,如下:
1 2 | dataset_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=4, shuffle=True) |
同样的,在 C++ API 里,我们这样做:
1 2 3 4 | auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>( std::move(custom_dataset), batch_size ); |
这里面的 SequentialSampler
类负责按照我们提供的数据顺序来生成样本,关于这个类的文档可以参考这里。
关于函数 torch::data::make_data_loader
可以在这里找到相关文档。
在训练循环中,可以这样读取 data 和 target:
1 2 3 4 | for(auto& batch: *data_loader) { auto data = batch.data; auto target = batch.target.squeeze(); } |
这里每次取得的数据大小取决于之前 torch::data::make_data_loader()
函数中传入的 batch_size
大小。
定义超参数
这里说到超参数,我们通常是指:
- Batch Size
- Optimizer 优化器
- Loss Function 损失函数
Batch Size
对于 Batch size
我们上一节讲过如何定义了。
优化器
这里我们看一下如何定义优化器:
1 2 3 | // We need to define the network first auto net = std::make_shared<Net>(); torch::optim::Adam optimizer(net->parameters(), torch::optim::AdamOptions(1e-3)); |
目前 PyTorch C++ API 仅支持以下优化器:
损失函数
关于损失函数,我们这里以 nll_loss
为例:
1 2 3 4 5 | auto output = net->forward(data); auto loss = torch::nll_loss(output, target); // To backpropagate loss loss.backward() |
如果要输出 loss 的具体值,可以使用 loss.item<float>()
。
开始训练
训练循环部分基本上逻辑和 Python 中是一样的,见代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | dataset_size = custom_dataset.size().value(); int n_epochs = 10; // Number of epochs for(int epoch=1; epoch<=n_epochs; epoch++) { for(auto& batch: *data_loader) { auto data = batch.data; auto target = batch.target.squeeze(); // Convert data to float32 format and target to Int64 format // Assuming you have labels as integers data = data.to(torch::kF2); target = target.to(torch::kInt64); // Clear the optimizer parameters optimizer.zero_grad(); auto output = net->forward(data); auto loss = torch::nll_loss(output, target); // Backpropagate the loss loss.backward(); // Update the parameters optimizer.step(); cout << "Train Epoch: %d/%ld [%5ld/%5d] Loss: %.4f" << epoch << n_epochs << batch_index * batch.data.size(0) << dataset_size << loss.item<float>() << endl; } } // Save the model torch::save(net, "best_model.pt"); |
注意最后采用 torch::save()
来保存训练好的模型。
下一节我们将用 PyTorch C++ API 实现一个完整的猫狗分类器,敬请期待。
本站微信群、QQ群(三群号 726282629):
评论列表(2)