In the previous post, we discussed how to implement VGG-16 with the PyTorch C++ API to classify the MNIST dataset. In this post, we look at how to use a custom dataset with the C++ API.
The full table of contents of the PyTorch C++ API tutorial series:
- PyTorch C++ API Part 1: Classifying MNIST with VGG-16
- PyTorch C++ API Part 2: Using a Custom Dataset
- PyTorch C++ API Part 3: Training the Network
- PyTorch C++ API Part 4: Building a Cat vs. Dog Classifier (1)
- PyTorch C++ API Part 5: Building a Cat vs. Dog Classifier (2)
Overview
Let's first review how we loaded the data in the previous tutorial:
```cpp
auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
    std::move(torch::data::datasets::MNIST("../../data")
                  .map(torch::data::transforms::Normalize<>(0.13707, 0.3081)))
        .map(torch::data::transforms::Stack<>()),
    64);
```
Let's break this down step by step.
First, we read the dataset into a tensor:
```cpp
auto data_set = torch::data::datasets::MNIST("../data");
```
Next, we apply some transforms:
```cpp
auto transformed_data_set = data_set
    .map(torch::data::transforms::Normalize<>(0.13707, 0.3081))
    .map(torch::data::transforms::Stack<>());
```
Finally, we set the batch_size to 64 when we move the dataset into the data loader:
```cpp
auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
    std::move(transformed_data_set), /*batch_size=*/64);
```
The data loader then feeds batches of this data to the network.
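To make that last step concrete, here is a small sketch (not from the original post) of iterating over the data_loader defined above. Because of the Stack<> transform, each batch arrives as a single torch::data::Example<> holding stacked tensors:

```cpp
// Iterate over the data loader; each batch is a torch::data::Example<>
// whose .data and .target are already stacked along the batch dimension.
for (auto& batch : *data_loader) {
    auto images = batch.data;   // shape: [64, 1, 28, 28] (the last batch may be smaller)
    auto labels = batch.target; // shape: [64]
    // The forward pass, loss computation, and optimizer step would go here.
    std::cout << images.sizes() << " " << labels.sizes() << std::endl;
}
```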
How It Works
To understand what is happening under the hood, let's look at how MNIST is read, i.e. the torch::data::datasets::MNIST class in the PyTorch source:
```cpp
namespace torch {
namespace data {
namespace datasets {
/// The MNIST dataset.
class TORCH_API MNIST : public Dataset<MNIST> {
 public:
  /// The mode in which the dataset is loaded.
  enum class Mode { kTrain, kTest };

  /// Loads the MNIST dataset from the root path.
  ///
  /// The supplied root path should contain the content of the unzipped
  /// MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
  explicit MNIST(const std::string& root, Mode mode = Mode::kTrain);

  /// Returns the Example at the given index.
  Example<> get(size_t index) override;

  /// Returns the size of the dataset.
  optional<size_t> size() const override;

  /// Returns true if this is the training subset of MNIST.
  bool is_train() const noexcept;

  /// Returns all images stacked into a single tensor.
  const Tensor& images() const;

  /// Returns all targets stacked into a single tensor.
  const Tensor& targets() const;

 private:
  Tensor images_, targets_;
};
} // namespace datasets
} // namespace data
} // namespace torch
```
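One detail worth noting in the declaration above is the Mode enum: it selects the train or test split when the dataset is constructed. For example (a small usage sketch, not from the original post):

```cpp
// Load the training split (the default) and the test split of MNIST.
auto train_set = torch::data::datasets::MNIST("../data");
auto test_set = torch::data::datasets::MNIST(
    "../data", torch::data::datasets::MNIST::Mode::kTest);
```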
The MNIST class constructor:
```cpp
MNIST::MNIST(const std::string& root, Mode mode)
    : images_(read_images(root, mode == Mode::kTrain)),
      targets_(read_targets(root, mode == Mode::kTrain)) {}
```
We can see that it calls two functions:
- read_images(root, mode) reads the images
- read_targets(root, mode) reads the image labels
Let's look at how each of these two functions works.
read_images(root, mode)
```cpp
Tensor read_images(const std::string& root, bool train) {
  // kTrainImagesFilename and kTestImagesFilename are specific to the MNIST dataset
  const auto path =
      join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);

  // Open the images file
  std::ifstream images(path, std::ios::binary);
  TORCH_CHECK(images, "Error opening images file at ", path);

  // kTrainSize = len(training_data), kTestSize = len(testing_data)
  const auto count = train ? kTrainSize : kTestSize;

  // Header fields specific to the MNIST format
  // (see http://yann.lecun.com/exdb/mnist/)
  expect_int32(images, kImageMagicNumber);
  expect_int32(images, count);
  expect_int32(images, kImageRows);
  expect_int32(images, kImageColumns);

  // Allocate an empty tensor of image size (count, channels, height, width)
  auto tensor =
      torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);

  // Read the raw pixel bytes directly into the tensor's memory
  images.read(reinterpret_cast<char*>(tensor.data_ptr()), tensor.numel());

  // Normalize pixel values from [0, 255] to [0, 1]
  return tensor.to(torch::kFloat32).div_(255);
}
```
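The expect_int32 calls above are an internal helper in the same source file. Conceptually, each call reads a 32-bit big-endian integer from the file header and checks it against the expected value, since MNIST files store their headers (magic number, item count, rows, columns) as big-endian int32 fields. A rough sketch of the idea, my own reconstruction rather than the verbatim PyTorch code:

```cpp
#include <cstdint>
#include <fstream>

// MNIST header fields are 32-bit big-endian integers.
uint32_t read_int32_be(std::ifstream& stream) {
  uint8_t bytes[4];
  stream.read(reinterpret_cast<char*>(bytes), 4);
  return (uint32_t(bytes[0]) << 24) | (uint32_t(bytes[1]) << 16) |
         (uint32_t(bytes[2]) << 8) | uint32_t(bytes[3]);
}

// Read one header field and verify it matches the expected value.
void expect_int32(std::ifstream& stream, uint32_t expected) {
  const auto value = read_int32_be(stream);
  TORCH_CHECK(value == expected, "Expected ", expected, " but found ", value);
}
```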
read_targets(root, mode)
```cpp
Tensor read_targets(const std::string& root, bool train) {
  // kTrainTargetsFilename and kTestTargetsFilename are specific to the MNIST dataset
  const auto path =
      join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);

  // Open the labels file
  std::ifstream targets(path, std::ios::binary);
  TORCH_CHECK(targets, "Error opening targets file at ", path);

  // kTrainSize = len(training_labels), kTestSize = len(testing_labels)
  const auto count = train ? kTrainSize : kTestSize;

  expect_int32(targets, kTargetMagicNumber);
  expect_int32(targets, count);

  // Allocate an empty tensor sized to the number of labels
  auto tensor = torch::empty(count, torch::kByte);

  // Read the raw label bytes directly into the tensor's memory
  targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count);

  return tensor.to(torch::kInt64);
}
```
There are also a couple of helper functions:
```cpp
Example<> MNIST::get(size_t index) {
  return {images_[index], targets_[index]};
}

optional<size_t> MNIST::size() const {
  return images_.size(0);
}
```
These two functions return a single image with its label, and the size of the dataset, respectively.
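They are exactly the two hooks the data loader relies on: size() tells the sampler how many indices are valid, and get(index) is called for every sampled index. You can also call them directly on a dataset instance, for example (a small usage sketch, not from the original post):

```cpp
auto data_set = torch::data::datasets::MNIST("../data");

std::cout << data_set.size().value() << std::endl;        // 60000 for the training split

auto example = data_set.get(0);                           // a torch::data::Example<>
std::cout << example.data.sizes() << std::endl;           // [1, 28, 28]
std::cout << example.target.item<int64_t>() << std::endl; // label of the first image
```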
The Workflow for a Custom Dataset
Having walked through the source code above, we now know the rough workflow for a custom dataset:
- Read the data and labels
- Convert them to tensors
- Define the get() and size() functions
- Instantiate the dataset class
- Pass the class instance to the data loader
A Custom Dataset Example
Next, let's look at a concrete example. Here is the overall skeleton of the code:
```cpp
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <iostream>
#include <vector>
#include <tuple>
#include <opencv2/opencv.hpp>
#include <string>

using namespace std;

/* Convert and load an image into a tensor from its file location */
torch::Tensor read_data(std::string location) {
    // Read the image here
    // Return the tensor form of the image
    return torch::Tensor();
}

/* Convert an integer label into a tensor */
torch::Tensor read_label(int label) {
    // Read the label here
    // Convert it to a tensor and return it
    return torch::Tensor();
}

/* Load all images (given by their file paths) as tensors */
vector<torch::Tensor> process_images(vector<string> list_images) {
    cout << "Reading Images..." << endl;
    // Return a vector of the tensor form of all the images
    return vector<torch::Tensor>();
}

/* Load all integer labels as tensors */
vector<torch::Tensor> process_labels(vector<int> list_labels) {
    cout << "Reading Labels..." << endl;
    // Return a vector of the tensor form of all the labels
    return vector<torch::Tensor>();
}

class CustomDataset : public torch::data::Dataset<CustomDataset> {
private:
    // Two vectors of tensors, one for images and one for labels
    vector<torch::Tensor> images, labels;

public:
    // Constructor
    CustomDataset(vector<string> list_images, vector<int> list_labels) {
        images = process_images(list_images);
        labels = process_labels(list_labels);
    }

    // Override get() to return the tensors at the given index
    torch::data::Example<> get(size_t index) override {
        torch::Tensor sample_img = images.at(index);
        torch::Tensor sample_label = labels.at(index);
        return {sample_img.clone(), sample_label.clone()};
    }

    // Return the length of the dataset
    torch::optional<size_t> size() const override {
        return labels.size();
    }
};
```
Here we use OpenCV to read the image data, which is fairly straightforward (the second argument is the read flag; 1 corresponds to cv::IMREAD_COLOR, i.e. a 3-channel BGR image):
```cpp
cv::imread(std::string location, int flags)
```
Note that the image has to be converted to the tensor layout PyTorch expects, i.e. (batch_size, channels, height, width):
```cpp
torch::Tensor read_data(std::string loc) {
    // Read the image at the given location
    cv::Mat img = cv::imread(loc, 1);

    // Wrap the raw image data in a tensor (Height x Width x Channels),
    // then reorder it to Channels x Height x Width
    torch::Tensor img_tensor =
        torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
    img_tensor = img_tensor.permute({2, 0, 1});

    // from_blob does not take ownership of the memory, so clone() copies the
    // data before the cv::Mat goes out of scope
    return img_tensor.clone();
}
```
Reading a label:
```cpp
// Read a label (int) and convert it to torch::Tensor
torch::Tensor read_label(int label) {
    torch::Tensor label_tensor = torch::full({1}, label);
    return label_tensor.clone();
}
```
The final code:
```cpp
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <iostream>
#include <vector>
#include <tuple>
#include <opencv2/opencv.hpp>
#include <string>

using namespace std;

/* Convert and load an image into a tensor from its file location */
torch::Tensor read_data(std::string loc) {
    // Read the image and resize it to a fixed size
    cv::Mat img = cv::imread(loc, 1);
    cv::resize(img, img, cv::Size(1920, 1080), 0, 0, cv::INTER_CUBIC);
    std::cout << "Sizes: " << img.size() << std::endl;

    // Wrap the raw image data in a tensor and reorder it to Channels x Height x Width
    torch::Tensor img_tensor =
        torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
    img_tensor = img_tensor.permute({2, 0, 1});

    return img_tensor.clone();
}

/* Convert an integer label into a tensor */
torch::Tensor read_label(int label) {
    torch::Tensor label_tensor = torch::full({1}, label);
    return label_tensor.clone();
}

/* Load all images (given by their file paths) as tensors */
vector<torch::Tensor> process_images(vector<string> list_images) {
    cout << "Reading Images..." << endl;
    vector<torch::Tensor> states;
    for (std::vector<string>::iterator it = list_images.begin(); it != list_images.end(); ++it) {
        torch::Tensor img = read_data(*it);
        states.push_back(img);
    }
    return states;
}

/* Load all integer labels as tensors */
vector<torch::Tensor> process_labels(vector<int> list_labels) {
    cout << "Reading Labels..." << endl;
    vector<torch::Tensor> labels;
    for (std::vector<int>::iterator it = list_labels.begin(); it != list_labels.end(); ++it) {
        torch::Tensor label = read_label(*it);
        labels.push_back(label);
    }
    return labels;
}

class CustomDataset : public torch::data::Dataset<CustomDataset> {
private:
    // Two vectors of tensors, one for images and one for labels
    vector<torch::Tensor> images, labels;

public:
    // Constructor
    CustomDataset(vector<string> list_images, vector<int> list_labels) {
        images = process_images(list_images);
        labels = process_labels(list_labels);
    }

    // Override get() to return the tensors at the given index
    torch::data::Example<> get(size_t index) override {
        torch::Tensor sample_img = images.at(index);
        torch::Tensor sample_label = labels.at(index);
        return {sample_img.clone(), sample_label.clone()};
    }

    // Return the length of the dataset
    torch::optional<size_t> size() const override {
        return labels.size();
    }
};

int main(int argc, char** argv) {
    vector<string> list_images; // list of image paths
    vector<int> list_labels;    // list of integer labels

    // Initialize the dataset; the only transform applied here is Stack<>
    auto custom_dataset = CustomDataset(list_images, list_labels)
                              .map(torch::data::transforms::Stack<>());
}
```
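With the dataset class in place, the last step is the same as for the built-in MNIST dataset: hand it to a data loader and iterate over batches. A minimal sketch continuing inside main() above, assuming list_images and list_labels have been filled in (the RandomSampler and batch size of 64 are my own choices, not part of the original code):

```cpp
// Wrap the custom dataset in a data loader; the Stack<> transform applied above
// collates individual Example<> samples into batched tensors.
auto data_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
    std::move(custom_dataset), /*batch_size=*/64);

for (auto& batch : *data_loader) {
    // batch.data:   [batch_size, 3, 1080, 1920] (images)
    // batch.target: [batch_size, 1]             (labels)
    std::cout << batch.data.sizes() << " " << batch.target.sizes() << std::endl;
}
```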
In the next tutorial, we will look at how to use this custom data loader with a CNN.