【DETR】训练自己的数据集以及YOLO数据集格式(txt)转化成COCO格式(json) ...

打印 上一主题 下一主题

主题 2037|帖子 2037|积分 6111

1.DETR介绍

DETR(Detection with TRansformers)是基于transformer的端对端目标检测,无NMS后处理步调,无anchor。
代码链接:https://github.com/facebookresearch/detr

2.数据集处理

DETR需要的数据集格式为coco格式,这里我是用自己的YOLO格式数据集转化成COCO格式,然后进行训练的。
YOLO数据集的组织格式是:
其中images里面分别存放训练集train和验证集val的图片,labels存放训练集train和验证集val的txt标签。

要转化成顺应DETR模子读取的COCO数据集的组织形式是:
其中train2017存放训练集的图片,val2017存放验证集的图片,
annotations文件夹里面存放train和val的json标签。

下面是转化代码:


  • 需要进行类别映射,每个类别对应的id分别存放在categories里面,这里我没有效classes.txt文件存放,相当于直接把classes.txt里面的类别写出来了。
  • 我的图片是png格式的,如果图片是jpg格式的,将png改成jpg即可。image_name = filename.replace(‘.txt’, ‘.jpg’)
  • 最后修改文件路径,改成自己的路径,这里最后会输出train和val的json文件,图片不会处理,按上述目次组织形式将图片组织起来即可。
  • 生成的文件夹记得改为instances_train2017.json这种样子
  1. import os
  2. import json
  3. from PIL import Image
  4. # 定义类别映射
  5. categories = [
  6.     {"id": 0, "name": "Double hexagonal column"},
  7.     {"id": 1, "name": "Flange nut"},
  8.     {"id": 2, "name": "Hexagon nut"},
  9.     {"id": 3, "name": "Hexagon pillar"},
  10.     {"id": 4, "name": "Hexagon screw"},
  11.     {"id": 5, "name": "Hexagonal steel column"},
  12.     {"id": 6, "name": "Horizontal bubble"},
  13.     {"id": 7, "name": "Keybar"},
  14.     {"id": 8, "name": "Plastic cushion pillar"},
  15.     {"id": 9, "name": "Rectangular nut"},
  16.     {"id": 10, "name": "Round head screw"},
  17.     {"id": 11, "name": "Spring washer"},
  18.     {"id": 12, "name": "T-shaped screw"}
  19. ]
  20. def yolo_to_coco(yolo_images_dir, yolo_labels_dir, output_json_path):
  21.     # 初始化 COCO 数据结构
  22.     data = {
  23.         "images": [],
  24.         "annotations": [],
  25.         "categories": categories
  26.     }
  27.     image_id = 1
  28.     annotation_id = 1
  29.     def get_image_size(image_path):
  30.         with Image.open(image_path) as img:
  31.             return img.width, img.height
  32.     # 遍历标签目录
  33.     for filename in os.listdir(yolo_labels_dir):
  34.         if not filename.endswith('.txt'):
  35.             continue  # 只处理 .txt 文件
  36.         image_name = filename.replace('.txt', '.png')# 如果图片是jpg格式的,将png改成jpg即可。
  37.         
  38.         image_path = os.path.join(yolo_images_dir, image_name)
  39.         if not os.path.exists(image_path):
  40.             print(f"⚠️ 警告: 图像 {image_name} 不存在,跳过 {filename}")
  41.             continue
  42.         image_width, image_height = get_image_size(image_path)
  43.         image_info = {
  44.             "id": image_id,
  45.             "width": image_width,
  46.             "height": image_height,
  47.             "file_name": image_name
  48.         }
  49.         data["images"].append(image_info)
  50.         with open(os.path.join(yolo_labels_dir, filename), 'r') as file:
  51.             lines = file.readlines()
  52.         for line in lines:
  53.             parts = line.strip().split()
  54.             if len(parts) != 5:
  55.                 print(f"⚠️ 警告: 标签 {filename} 格式错误: {line.strip()}")
  56.                 continue
  57.             category_id = int(parts[0])
  58.             x_center = float(parts[1]) * image_width
  59.             y_center = float(parts[2]) * image_height
  60.             bbox_width = float(parts[3]) * image_width
  61.             bbox_height = float(parts[4]) * image_height
  62.             x_min = int(x_center - bbox_width / 2)
  63.             y_min = int(y_center - bbox_height / 2)
  64.             bbox = [x_min, y_min, bbox_width, bbox_height]
  65.             area = bbox_width * bbox_height
  66.             annotation_info = {
  67.                 "id": annotation_id,
  68.                 "image_id": image_id,
  69.                 "category_id": category_id,
  70.                 "bbox": bbox,
  71.                 "area": area,
  72.                 "iscrowd": 0
  73.             }
  74.             data["annotations"].append(annotation_info)
  75.             annotation_id += 1
  76.         image_id += 1
  77.     os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
  78.     with open(output_json_path, 'w') as json_file:
  79.         json.dump(data, json_file, indent=4)
  80.     print(f"✅ 转换完成: {output_json_path}")
  81. # 输入路径 (YOLO 格式数据集)
  82. yolo_base_dir = "/home/yu/Yolov8/ultralytics-main/mydata0"
  83. yolo_train_images = os.path.join(yolo_base_dir, "images/train")
  84. yolo_train_labels = os.path.join(yolo_base_dir, "labels/train")
  85. yolo_val_images = os.path.join(yolo_base_dir, "images/val")
  86. yolo_val_labels = os.path.join(yolo_base_dir, "labels/val")
  87. # 输出路径 (COCO 格式)
  88. coco_base_dir = "/home/yu/Yolov8/ultralytics-main/mydata0_coco"
  89. coco_train_json = os.path.join(coco_base_dir, "annotations/instances_train.json")
  90. coco_val_json = os.path.join(coco_base_dir, "annotations/instances_val.json")
  91. # 运行转换
  92. yolo_to_coco(yolo_train_images, yolo_train_labels, coco_train_json)
  93. yolo_to_coco(yolo_val_images, yolo_val_labels, coco_val_json)
复制代码
3.转化效果可视化

COCO数据集JSON文件格式分为以下几个字段。
  1. {
  2.     "info": info, # dict
  3.      "licenses": [license], # list ,内部是dict
  4.      "images": [image], # list ,内部是dict
  5.      "annotations": [annotation], # list ,内部是dict
  6.      "categories": # list ,内部是dict
  7. }
复制代码
可以运行以下脚本查看转化后的标签是否与图片目标对应:


  • 修改代码的json_path和img_path,json_path是标签对应的路径,img_path是图像对应的路径
  1. '''
  2. 该代码的功能是:读取图像以及对应bbox的信息
  3. '''
  4. import os
  5. from pycocotools.coco import COCO
  6. from PIL import Image, ImageDraw
  7. import matplotlib.pyplot as plt
  8. json_path = "/home/yu/Yolov8/ultralytics-main/mydata0_coco/annotations/instances_val.json"
  9. img_path = ("/home/yu/Yolov8/ultralytics-main/mydata0_coco/images/val")
  10. # load coco data
  11. coco = COCO(annotation_file=json_path)
  12. # get all image index info
  13. ids = list(sorted(coco.imgs.keys()))
  14. print("number of images: {}".format(len(ids)))
  15. # get all coco class labels
  16. coco_classes = dict([(v["id"], v["name"]) for k, v in coco.cats.items()])
  17. # 遍历前三张图像
  18. for img_id in ids[:3]:
  19.     # 获取对应图像id的所有annotations idx信息
  20.     ann_ids = coco.getAnnIds(imgIds=img_id)
  21.     # 根据annotations idx信息获取所有标注信息
  22.     targets = coco.loadAnns(ann_ids)
  23.     # get image file name
  24.     path = coco.loadImgs(img_id)[0]['file_name']
  25.     # read image
  26.     img = Image.open(os.path.join(img_path, path)).convert('RGB')
  27.     draw = ImageDraw.Draw(img)
  28.     # draw box to image
  29.     for target in targets:
  30.         x, y, w, h = target["bbox"]
  31.         x1, y1, x2, y2 = x, y, int(x + w), int(y + h)
  32.         draw.rectangle((x1, y1, x2, y2))
  33.         draw.text((x1, y1), coco_classes[target["category_id"]])
  34.     # show image
  35.     plt.imshow(img)
  36.     plt.show()
复制代码
运行该代码,你将会看到你的标签是否对应:
如果目标没有界限框则说明你转化的json不对!


4.数据集训练

4.1修改pth文件

将它的pth文件改一下,由于他是用的coco数据集,而我们只需要训练自己的数据集,就是下图这个文件,这是它本来的

新建一个.py文件,运行下面代码,就会生成一个你数据集所需要的物体数目的pth,记得改类别数!。
  1. import torch
  2. pretrained_weights  = torch.load('detr-r50-e632da11.pth')
  3. num_class = 14 #这里是你的物体数+1,因为背景也算一个
  4. pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1, 256)
  5. pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
  6. torch.save(pretrained_weights, "detr-r50_%d.pth"%num_class
复制代码
这是我们生成的。

4.2类别参数修改

修改models/detr.py文件,build()函数中,可以将红框部分的代码都注释掉,直接设置num_classes为自己的类别数+1
由于我的类别数是13,以是我这里num_classes=14

4.3训练

修改main.py文件的epochs、lr、batch_size等训练参数:
以下这些参数都在get_args_parser()函数里面。

修改自己的数据集路径:

设置输出路径:

修改resume为自己的预训练权重文件路径
这里就是你刚才运行脚本生成的pth文件的路径:

运行main.py文件
大概可以通过命令行运行:
  1. python main.py --dataset_file "coco" --coco_path "/home/yu/Yolov8/ultralytics-main/mydata0_coco" --epoch 300 --lr=1e-4 --batch_size=8 --num_workers=4 --output_dir="outputs" --resume="detr_r50_14.pth"
复制代码
5.乐成运行!


6.参考文献

1.【DETR】训练自己的数据集-实践笔记
2. yolo数据集格式(txt)转coco格式,方便mmyolo转标签格式
3. windows10复现DEtection TRansformers(DETR)并实现自己的数据集

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

正序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

王柳

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表