本章将具体先容如何使用libtorch自带的数据加载模块。
自定义数据集
使用自定义数据集
简介
要自定义数据加载模块,需要继承torch::data: ataset这个基类实现派生类。
与pytorch中需要实现初始化函数init,获取函数getitem以及数据集大小函数len类似的是,在libtorch中同样需要处理好初始化函数,get()函数和size()函数。
例程的代码结构
例程中使用了一个图像分类使命来进行先容,使用pytorch官网提供的昆虫分类数据集
遍历图像文件
例程中使用了io.h来遍历文件夹。
起首实现遍历文件夹的函数:
担当数据集文件夹路径image_dir和图片类型image_type,将遍历到的图片路径和其种别分别存储到list_images和list_labels,末了lable变量用于表示种别计数。
通过该函数,会得到所有图像的绝对地址,通过这些地址就可以获得图像。
- #include <io.h>
- void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);
- void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label)
- {
- /*
- * path:文件夹地址
- * type:图片类型
- * list_images:所有图片的名称
- * list_label:各个图片的标签,也就是所属的类
- * label:类别的个数
- */
- long long hFile = 0; //句柄
- struct _finddata_t fileInfo;// 记录读取到文件的信息
- std::string pathName;
- // 调用_findfirst函数,其第一个参数为遍历的文件夹路径,*代表任意文件。注意路径最后,需要添加通配符
- // 如果失败,返回-1,否则,就会返回文件句柄,并且将找到的第一个文件信息放在_finddata_t结构体变量中
- if ((hFile = _findfirst(pathName.assign(path).append("\\*.*").c_str(), &fileInfo)) == -1)
- {
- return;
- }
- // 通过do{}while循环,遍历所有文件
- do
- {
- const char* filename = fileInfo.name;// 获得文件名
- const char* t = type.data();
- if (fileInfo.attrib&_A_SUBDIR) //是子文件夹
- {
- //遍历子文件夹中的文件(夹)
- if (strcmp(filename, ".") == 0 || strcmp(filename, "..") == 0) //子文件夹目录是.或者..
- continue;
- std::string sub_path = path + "\" + fileInfo.name;// 增加多一级
- label++;
- load_data_from_folder(sub_path, type, list_images, list_labels, label);// 读取子文件夹的文件
- }
- else //判断是不是后缀为type文件
- {
- if (strstr(filename, t))
- {
- std::string image_path = path + "\" + fileInfo.name;// 构造图像的地址
- list_images.push_back(image_path);
- list_labels.push_back(label);
- }
- }
- //其第一个参数就是_findfirst函数的返回值,第二个参数同样是文件信息结构体
- } while (_findnext(hFile, &fileInfo) == 0);
- return;
- }
复制代码 自定义DataSet
需要继承torch::data: ataset,定义私有变量image_paths和labels分别存储图片路径和种别,是两个vector变量。
在构造函数中,调用图像遍历函数来获得所有图像的地址与种别;并且需要重写get()与size()。
在get()中根据传入的index来获得指定的图像,而且可以在get()函数中对图像进行一些处理,比方调解大小或数据加强等。然后使用torch::from_blob将图像数据与label都转换为张量。
此中图像还需要使用permute(),将张量转换为Channels x Height x Width的结构、
- class myDataset:public torch::data::Dataset<myDataset>{
- public:
- int num_classes = 0;
- myDataset(std::string image_dir, std::string type){
- // 调用遍历文件的函数
- load_data_from_folder(image_dir, std::string(type), image_paths, labels, num_classes);
- }
- // 重写 get(),根据传入的index来获得指定的数据
- torch::data::Example<> get(size_t index) override{
- std::string image_path = image_paths.at(index);// 根据index得到指定的图像
- cv::Mat image = cv::imread(image_path);// 读取图像
- cv::resize(image, image, cv::Size(224, 224));// 调整大小,使得尺寸统一,用于张量stack
- int label = labels.at(index);//
- // 将图像数据转换为张量image_tensor,尺寸{image.rows, image.cols, 3},元素的数据类型为byte
- // Channels x Height x Width
- torch::Tensor img_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 });
- //
- torch::Tensor label_tensor = torch::full({ 1 }, label);
- return {img_tensor.clone(), label_tensor.clone()};// 返回图像及其标签
- }
- // Return the length of data
- torch::optional<size_t> size() const override {
- return image_paths.size();
- };
- private:
- std::vector<std::string> image_paths;// 所有图像的地址
- std::vector<int> labels;// 所有图像的类别
- };
复制代码 使用自定义数据集
起首创建一个自定义数据集对象,然后它进行一些transform处理
- auto mydataset = myDataset(image_dir,".jpg").map(torch::data::transforms::Stack<>());
复制代码 然后需要使用torch::data::make_data_loader来传入批数据(Batch),对应于pytorch中的torch.utils.data.DataLoader
这内里的 SequentialSampler 类负责按照我们提供的数据顺序来生成样本。
需要传入数据集对象与批次尺寸。
- auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
- std::move(mydataset), batch_size);
复制代码 然后可以通过循环来遍历每个批次中的data(image)与target(标签),也就是自定义数据会合的get() 所返回两个数据:image与label
这里每次取得的数据大小取决于之前 torch::data::make_data_loader() 函数中传入的 batch_size 大小。
- for(auto &batch: *mdataloader){
- auto data = batch.data;
- auto target = batch.target;
- std::cout<<data.sizes()<<target<<std::endl;
- }
复制代码 免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |