【数据分析】coco格式数据生成yolo数据可视化

打印 上一主题 下一主题

主题 1004|帖子 1004|积分 3022

yolo的数据可视化很详细,coco格式没有。所以写了一个接口。
输入:coco格式的instances.json
输出:生成像yolo那样的标注文件统计并可视化
  1. import os
  2. import random
  3. import numpy as np
  4. import pandas as pd
  5. import matplotlib
  6. import matplotlib.pyplot as plt
  7. import seaborn as sn
  8. from glob import glob
  9. from PIL import Image, ImageDraw
  10. import json
  11. """
  12. 功能:
  13.     读取instances.json
  14.     生成像yolo那样的标注文件统计并可视化
  15.    
  16. """
  17. def convert(size, box):
  18.     # size(img_width, img_height)
  19.     # box=[x_min, y_min, width, height]
  20.     # coco转yolo   
  21.     dw = 1. / (size[0])
  22.     dh = 1. / (size[1])
  23.     x = box[0] + box[2] / 2.0
  24.     y = box[1] + box[3] / 2.0
  25.     w = box[2]
  26.     h = box[3]
  27.     #round函数确定(xmin, ymin, xmax, ymax)的小数位数
  28.     x = round(x * dw, 6)
  29.     w = round(w * dw, 6)
  30.     y = round(y * dh, 6)
  31.     h = round(h * dh, 6)
  32.     return (x, y, w, h)
  33. def plot_labels(labels, names=(), save_dir='',colors=[0,0,255]):
  34.     # plot dataset labels
  35.     print('Plotting labels... ')
  36.     c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
  37.     nc = int(c.max() + 1)  # number of classes
  38.     x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
  39.     # seaborn correlogram
  40.     sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
  41.     plt.savefig(os.path.join(save_dir, 'labels_correlogram.jpg'), dpi=200)
  42.     plt.close()
  43.     # matplotlib labels
  44.     matplotlib.use('svg')  # faster
  45.     ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  46.     y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  47.     # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # update colors bug #3195
  48.     ax[0].set_ylabel('instances')
  49.     if 0 < len(names) < 30:
  50.         ax[0].set_xticks(range(len(names)))
  51.         ax[0].set_xticklabels(names, rotation=90, fontsize=10)
  52.     else:
  53.         ax[0].set_xlabel('classes')
  54.     sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
  55.     sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
  56.     # rectangles
  57.     labels[:, 1:3] = 0.5  # center
  58.     labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
  59.     img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
  60.     for cls, *box in labels[:1000]:
  61.         ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls)-1])  # plot
  62.     ax[1].imshow(img)
  63.     ax[1].axis('off')
  64.     for a in [0, 1, 2, 3]:
  65.         for s in ['top', 'right', 'left', 'bottom']:
  66.             ax[a].spines[s].set_visible(False)
  67.     plt.savefig(os.path.join(save_dir, 'labels.jpg'), dpi=200)
  68.     matplotlib.use('Agg')
  69.     plt.close()
  70. def xywh2xyxy(x):
  71.     # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  72.     y = np.copy(x)
  73.     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
  74.     y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
  75.     y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
  76.     y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
  77.     return y
  78. def main(json_name,save_root,data_name):
  79.     # 获取当前数据集中所有json文件
  80.    
  81.     with open(json_name, 'r', encoding='utf-8') as file:
  82.         result = json.load(file)
  83.     # 每个类别一个颜色
  84.     category=[]
  85.     for i in result['categories']:
  86.         category.append(i['name'])# 类别
  87.     num_classes = len(category)  # 类别数
  88.     colors = [(random.randint(0,255),random.randint(0,255),random.randint(0,255)) for _ in range(num_classes)]  # 每个类别生成一个随机颜色
  89.     # 统计标注信息
  90.     shapes = []  # 标注框
  91.     ids = []  # 类别名的索引
  92.     for i in result['annotations']:
  93.         img_height=result['images'][i['image_id']-1]['height']
  94.         img_width=result['images'][i['image_id']-1]['width']
  95.         label_id=i['category_id']
  96.         ids.append([label_id])
  97.         (x, y, w, h)=convert([img_width, img_height], i['bbox'])
  98.         shapes.append([x, y, w, h])
  99.     shapes = np.array(shapes)
  100.     ids = np.array(ids)
  101.     lbs = np.hstack((ids, shapes))
  102.     plot_labels(labels=lbs, names=np.array(category),save_dir=os.path.join(save_root,data_name),colors=colors)
  103.     print("可视化已保存:", os.path.join(save_root,data_name, "label.jpg"))
  104. if __name__ == "__main__":
  105.         json_name = os.path.join(path,data_name,'annotations','instances.json')
  106.         save_root='保存路径'
  107.         data_name='数据集的名称'
  108.     main(json_name,save_root,data_name)
复制代码
labels.jpg

labels_correlogram.jpg


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

泉缘泉

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表