libtorch学习历程(四):数据加载模块

火影  金牌会员 | 2024-8-25 14:50:20 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 544|帖子 544|积分 1632

本章将具体先容如何使用libtorch自带的数据加载模块。
自定义数据集
使用自定义数据集
简介

要自定义数据加载模块,需要继承torch::data:ataset这个基类实现派生类。
与pytorch中需要实现初始化函数init获取函数getitem以及数据集大小函数len类似的是,在libtorch中同样需要处理好初始化函数get()函数size()函数
例程的代码结构

例程中使用了一个图像分类使命来进行先容,使用pytorch官网提供的昆虫分类数据集
遍历图像文件

例程中使用了io.h来遍历文件夹。
起首实现遍历文件夹的函数:
担当数据集文件夹路径image_dir图片类型image_type,将遍历到的图片路径和其种别分别存储到list_images和list_labels,末了lable变量用于表示种别计数。
通过该函数,会得到所有图像的绝对地址,通过这些地址就可以获得图像。
  1. #include <io.h>
  2. void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);
  3. void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label)
  4. {
  5.     /*
  6.      * path:文件夹地址
  7.      * type:图片类型
  8.      * list_images:所有图片的名称
  9.      * list_label:各个图片的标签,也就是所属的类
  10.      * label:类别的个数
  11.     */
  12.     long long hFile = 0; //句柄
  13.     struct _finddata_t fileInfo;// 记录读取到文件的信息
  14.     std::string pathName;
  15.     // 调用_findfirst函数,其第一个参数为遍历的文件夹路径,*代表任意文件。注意路径最后,需要添加通配符
  16.     // 如果失败,返回-1,否则,就会返回文件句柄,并且将找到的第一个文件信息放在_finddata_t结构体变量中
  17.     if ((hFile = _findfirst(pathName.assign(path).append("\\*.*").c_str(), &fileInfo)) == -1)
  18.     {
  19.         return;
  20.     }
  21.     // 通过do{}while循环,遍历所有文件
  22.     do
  23.     {
  24.         const char* filename = fileInfo.name;// 获得文件名
  25.         const char* t = type.data();
  26.         if (fileInfo.attrib&_A_SUBDIR) //是子文件夹
  27.         {
  28.             //遍历子文件夹中的文件(夹)
  29.             if (strcmp(filename, ".") == 0 || strcmp(filename, "..") == 0) //子文件夹目录是.或者..
  30.                 continue;
  31.             std::string sub_path = path + "\" + fileInfo.name;// 增加多一级
  32.             label++;
  33.             load_data_from_folder(sub_path, type, list_images, list_labels, label);// 读取子文件夹的文件
  34.         }
  35.         else //判断是不是后缀为type文件
  36.         {
  37.             if (strstr(filename, t))
  38.             {
  39.                 std::string image_path = path + "\" + fileInfo.name;// 构造图像的地址
  40.                 list_images.push_back(image_path);
  41.                 list_labels.push_back(label);
  42.             }
  43.         }
  44.       //其第一个参数就是_findfirst函数的返回值,第二个参数同样是文件信息结构体
  45.     } while (_findnext(hFile, &fileInfo) == 0);
  46.     return;
  47. }
复制代码
自定义DataSet

需要继承torch::data:ataset,定义私有变量image_paths和labels分别存储图片路径和种别,是两个vector变量。
在构造函数中,调用图像遍历函数来获得所有图像的地址与种别;并且需要重写get()与size()
在get()中根据传入的index来获得指定的图像,而且可以在get()函数中对图像进行一些处理,比方调解大小或数据加强等。然后使用torch::from_blob将图像数据与label都转换为张量。
此中图像还需要使用permute(),将张量转换为Channels x Height x Width的结构、
  1. class myDataset:public torch::data::Dataset<myDataset>{
  2. public:
  3.     int num_classes = 0;
  4.     myDataset(std::string image_dir, std::string type){
  5.         // 调用遍历文件的函数
  6.         load_data_from_folder(image_dir, std::string(type), image_paths, labels, num_classes);
  7.     }
  8.     // 重写 get(),根据传入的index来获得指定的数据
  9.     torch::data::Example<> get(size_t index) override{
  10.         std::string image_path = image_paths.at(index);// 根据index得到指定的图像
  11.         cv::Mat image = cv::imread(image_path);// 读取图像
  12.         cv::resize(image, image, cv::Size(224, 224));// 调整大小,使得尺寸统一,用于张量stack
  13.         int label = labels.at(index);//
  14.         // 将图像数据转换为张量image_tensor,尺寸{image.rows, image.cols, 3},元素的数据类型为byte
  15.         // Channels x Height x Width
  16.         torch::Tensor img_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 });
  17.         //
  18.         torch::Tensor label_tensor = torch::full({ 1 }, label);
  19.         return {img_tensor.clone(), label_tensor.clone()};// 返回图像及其标签
  20.     }
  21.     // Return the length of data
  22.     torch::optional<size_t> size() const override {
  23.         return image_paths.size();
  24.     };
  25. private:
  26.     std::vector<std::string> image_paths;// 所有图像的地址
  27.     std::vector<int> labels;// 所有图像的类别
  28. };
复制代码
使用自定义数据集

起首创建一个自定义数据集对象,然后它进行一些transform处理
  1. auto mydataset = myDataset(image_dir,".jpg").map(torch::data::transforms::Stack<>());
复制代码
然后需要使用torch::data::make_data_loader来传入批数据(Batch),对应于pytorch中的torch.utils.data.DataLoader

这内里的 SequentialSampler 类负责按照我们提供的数据顺序来生成样本。
需要传入数据集对象与批次尺寸
  1. auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
  2.                                                 std::move(mydataset), batch_size);
复制代码
然后可以通过循环来遍历每个批次中的data(image)与target(标签),也就是自定义数据会合的get() 所返回两个数据:image与label
这里每次取得的数据大小取决于之前 torch::data::make_data_loader() 函数中传入的 batch_size 大小
  1. for(auto &batch: *mdataloader){
  2.    auto data = batch.data;
  3.    auto target = batch.target;
  4.    std::cout<<data.sizes()<<target<<std::endl;
  5. }
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

火影

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表