基于 DINOv2 模型实现图搜图相似度检索任务

  金牌会员 | 2024-12-30 02:16:51 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 806|帖子 806|积分 2418

一、DINOv2 模型简介及利用

DINOv2是由Meta AI开发的第二代自监督视觉变更器模型,采用 Vision Transformer (ViT) 架构 。其核心特点是在无需人工标签的情况下,通过自监督学习技术,从海量无标注图像中学习故意义的视觉特征表现,类似于 NLP 领域的自监督 Base 模型,DINOv2 已经具有了对图像的明白能力,和强大的图像特征提取能力,因此它可以作为险些全部计算机视觉任务的骨干模型。
下面是官方演示地点:
   https://dinov2.metademolab.com/demos?category=segmentation
  深度估计结果:

语义分割结果:

GitHub 开源地点:
   https://github.com/facebookresearch/dinov2
  huggingface 模型地点:
   https://huggingface.co/facebook/dinov2-base/tree/main
  

本文借助 DINOv2 强大的特征提取能力,实现图图相似度检索任务,但开始前,起首 先相识一下如何基于 DINOv2 实现图像的相似度计算:
  1. from transformers import AutoImageProcessor, AutoModel
  2. from PIL import Image
  3. import matplotlib.pyplot as plt
  4. import torch
  5. plt.rcParams['font.sans-serif'] = ['SimHei']
  6. # 生成图像特征
  7. def gen_image_features(processor, model, device, image):
  8.     with torch.no_grad():
  9.         inputs = processor(images=image, return_tensors="pt").to(device)
  10.         outputs = model(**inputs)
  11.         image_features = outputs.last_hidden_state
  12.         image_features = image_features.mean(dim=1)
  13.         return image_features[0]
  14. # 计算两个图像的相似度
  15. def similarity_image(processor, model, device, image1, image2):
  16.     features1 = gen_image_features(processor, model, device, image1)
  17.     features2 = gen_image_features(processor, model, device, image2)
  18.     cos_sim = torch.cosine_similarity(features1, features2, dim=0)
  19.     cos_sim = (cos_sim + 1) / 2
  20.     return cos_sim.item()
  21. def main():
  22.     model_dir = "facebook/dinov2-base"
  23.     processor = AutoImageProcessor.from_pretrained(model_dir)
  24.     model = AutoModel.from_pretrained(model_dir)
  25.     device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
  26.     model.to(device)
  27.     image1 = Image.open("img/dog1.jpg")
  28.     image2 = Image.open("img/dog2.jpg")
  29.     similarity = similarity_image(processor, model, device, image1, image2)
  30.     plt.figure()
  31.     plt.axis('off')
  32.     plt.title(f"相似度: {similarity}")
  33.     plt.subplot(1, 2, 1)
  34.     plt.imshow(image1)
  35.     plt.subplot(1, 2, 2)
  36.     plt.imshow(image2)
  37.     plt.show()
  38. if __name__ == '__main__':
  39.     main()
复制代码


相似度计算的主要核心就是基于 DINOv2 天生特征向量,图图相似度检索也依赖这一点,不过需要多出一个特征向量的持久化存储端,整体实现架构如下图所示,其中特征向量存储采用 Milvus 数据库。
被检索图像数据集特征提取过程:

图像相似度检索过程:

关于 Milvus 的利用,可以参考下面这篇博客:
   Milvus 向量数据库介绍及利用
  二、图图相似度检索实现 - 图像特征持久化

起首准备图像数据集,这里我随便准备了几张猫和狗的图片:

创建 Milvus Collection,其中 DINOv2 特征向量维度为 768 维:
  1. from pymilvus import MilvusClient, DataType
  2. client = MilvusClient("http://192.168.0.5:19530")
  3. schema = MilvusClient.create_schema(
  4.     auto_id=True,
  5.     enable_dynamic_field=False,
  6. )
  7. schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
  8. schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=768)
  9. schema.add_field(field_name="image_name", datatype=DataType.VARCHAR, max_length=256)
  10. schema.verify()
  11. index_params = client.prepare_index_params()
  12. index_params.add_index(
  13.     field_name="id",
  14.     index_type="STL_SORT"
  15. )
  16. index_params.add_index(
  17.     field_name="vector",
  18.     index_type="IVF_FLAT",
  19.     metric_type="L2",
  20.     params={"nlist": 1024}
  21. )
  22. # 创建 collection
  23. client.create_collection(
  24.     collection_name="dinov2_collection",
  25.     schema=schema,
  26.     index_params=index_params
  27. )
复制代码
图像数据集提取特征向量后,持久化到 Milvus 中:
  1. from transformers import AutoImageProcessor, AutoModel
  2. from PIL import Image
  3. from pymilvus import MilvusClient
  4. from tqdm import tqdm
  5. import torch
  6. import os
  7. # 生成特征向量
  8. def gen_image_features(processor, model, device, image):
  9.     with torch.no_grad():
  10.         inputs = processor(images=image, return_tensors="pt").to(device)
  11.         outputs = model(**inputs)
  12.         image_features = outputs.last_hidden_state
  13.         image_features = image_features.mean(dim=1)
  14.         return image_features[0]
  15. def main():
  16.     # 创建Milvus客户端
  17.     client = MilvusClient("http://192.168.0.5:19530")
  18.     # 加载模型
  19.     model_dir = "facebook/dinov2-base"
  20.     processor = AutoImageProcessor.from_pretrained(model_dir)
  21.     model = AutoModel.from_pretrained(model_dir)
  22.     device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
  23.     model.to(device)
  24.     # 读取数据集
  25.     dataset_path = "./img"
  26.     for image_name in tqdm(os.listdir(dataset_path)):
  27.         image_path = os.path.join(dataset_path, image_name)
  28.         image = Image.open(image_path)
  29.         # 提取特征向量
  30.         features = gen_image_features(processor, model, device, image)
  31.         # 存吃至 milvus
  32.         client.insert(
  33.             collection_name="dinov2_collection",
  34.             data={
  35.                 "vector": features,
  36.                 "image_name": image_name
  37.             }
  38.         )
  39. if __name__ == '__main__':
  40.     main()
复制代码
运行后,在 Milvus insight 工具中,可以看到存储的内容:

三、图图相似度检索实现 - 图像特征检索

检索就是拿着当前图像的特征,去向量数据库中检索相似的特征信息,实现过程如下:
  1. from transformers import AutoImageProcessor, AutoModel
  2. from PIL import Image
  3. from pymilvus import MilvusClient
  4. import matplotlib.pyplot as plt
  5. import torch
  6. import os
  7. # 生成特征向量
  8. def gen_image_features(processor, model, device, image):
  9.     with torch.no_grad():
  10.         inputs = processor(images=image, return_tensors="pt").to(device)
  11.         outputs = model(**inputs)
  12.         image_features = outputs.last_hidden_state
  13.         image_features = image_features.mean(dim=1)
  14.         return image_features[0].tolist()
  15. def main():
  16.     # 创建Milvus客户端
  17.     client = MilvusClient("http://192.168.0.5:19530")
  18.     # 加载模型
  19.     model_dir = "facebook/dinov2-base"
  20.     processor = AutoImageProcessor.from_pretrained(model_dir)
  21.     model = AutoModel.from_pretrained(model_dir)
  22.     device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
  23.     model.to(device)
  24.     # 检索图像, 采用不在Milvus数据集中的图像
  25.     image = Image.open("./img2/dog.jpeg")
  26.     # 提取特征向量
  27.     features = gen_image_features(processor, model, device, image)
  28.     # 特征召回
  29.     results = client.search(
  30.         collection_name="dinov2_collection",
  31.         data=[features],
  32.         limit=2,
  33.         output_fields=["image_name"],
  34.         search_params={
  35.             "metric_type": "L2",
  36.             "params": {}
  37.         }
  38.     )
  39.     plt.figure()
  40.     plt.axis('off')
  41.     for i, res in enumerate(results[0]):
  42.         image_name = res["entity"]["image_name"]
  43.         image_path = os.path.join("./img", image_name)
  44.         image = Image.open(image_path)
  45.         plt.subplot(1, 2, (i+1))
  46.         plt.imshow(image)
  47.     plt.show()
  48. if __name__ == '__main__':
  49.     main()
复制代码
测试输入检索图像:

召回图像:

测试输入检索图像:

召回图像:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表