发表于 2024-12-30 02:16:51

基于 DINOv2 模型实现图搜图相似度检索任务

一、DINOv2 模型简介及利用

DINOv2是由Meta AI开发的第二代自监督视觉变更器模型,采用 Vision Transformer (ViT) 架构 。其核心特点是在无需人工标签的情况下,通过自监督学习技术,从海量无标注图像中学习故意义的视觉特征表现,类似于 NLP 领域的自监督 Base 模型,DINOv2 已经具有了对图像的明白能力,和强大的图像特征提取能力,因此它可以作为险些全部计算机视觉任务的骨干模型。
下面是官方演示地点:
   https://dinov2.metademolab.com/demos?category=segmentation
深度估计结果:
https://i-blog.csdnimg.cn/direct/9ca7212e53bb421183ec88797654a1e3.png
语义分割结果:
https://i-blog.csdnimg.cn/direct/ed2a7a75f7ae45268ac3209df772a49f.png
GitHub 开源地点:
   https://github.com/facebookresearch/dinov2
huggingface 模型地点:
   https://huggingface.co/facebook/dinov2-base/tree/main
https://i-blog.csdnimg.cn/direct/64cc0dddaf76443f8c8b07716a05ede3.png
本文借助 DINOv2 强大的特征提取能力,实现图图相似度检索任务,但开始前,起首 先相识一下如何基于 DINOv2 实现图像的相似度计算:
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import matplotlib.pyplot as plt
import torch
plt.rcParams['font.sans-serif'] = ['SimHei']
# 生成图像特征
def gen_image_features(processor, model, device, image):
    with torch.no_grad():
      inputs = processor(images=image, return_tensors="pt").to(device)
      outputs = model(**inputs)
      image_features = outputs.last_hidden_state
      image_features = image_features.mean(dim=1)
      return image_features

# 计算两个图像的相似度
def similarity_image(processor, model, device, image1, image2):
    features1 = gen_image_features(processor, model, device, image1)
    features2 = gen_image_features(processor, model, device, image2)
    cos_sim = torch.cosine_similarity(features1, features2, dim=0)
    cos_sim = (cos_sim + 1) / 2
    return cos_sim.item()

def main():
    model_dir = "facebook/dinov2-base"
    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
    model.to(device)
    image1 = Image.open("img/dog1.jpg")
    image2 = Image.open("img/dog2.jpg")
    similarity = similarity_image(processor, model, device, image1, image2)
    plt.figure()
    plt.axis('off')
    plt.title(f"相似度: {similarity}")
    plt.subplot(1, 2, 1)
    plt.imshow(image1)
    plt.subplot(1, 2, 2)
    plt.imshow(image2)
    plt.show()

if __name__ == '__main__':
    main()
https://i-blog.csdnimg.cn/direct/72db4d9cb018446fb08f05f0167c0e16.png
https://i-blog.csdnimg.cn/direct/e68bb6b25d254873a36a135a706db1cc.png
相似度计算的主要核心就是基于 DINOv2 天生特征向量,图图相似度检索也依赖这一点,不过需要多出一个特征向量的持久化存储端,整体实现架构如下图所示,其中特征向量存储采用 Milvus 数据库。
被检索图像数据集特征提取过程:
https://i-blog.csdnimg.cn/direct/577416ac566f42caa7e5c3631a76be62.png
图像相似度检索过程:
https://i-blog.csdnimg.cn/direct/70bc411cd56944c7b43a00d6d91aa781.png
关于 Milvus 的利用,可以参考下面这篇博客:
   Milvus 向量数据库介绍及利用
二、图图相似度检索实现 - 图像特征持久化

起首准备图像数据集,这里我随便准备了几张猫和狗的图片:
https://i-blog.csdnimg.cn/direct/19dab730dfed473fbe21936f26cbafa8.png
创建 Milvus Collection,其中 DINOv2 特征向量维度为 768 维:
from pymilvus import MilvusClient, DataType

client = MilvusClient("http://192.168.0.5:19530")

schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=False,
)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=768)
schema.add_field(field_name="image_name", datatype=DataType.VARCHAR, max_length=256)
schema.verify()
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="id",
    index_type="STL_SORT"
)
index_params.add_index(
    field_name="vector",
    index_type="IVF_FLAT",
    metric_type="L2",
    params={"nlist": 1024}
)
# 创建 collection
client.create_collection(
    collection_name="dinov2_collection",
    schema=schema,
    index_params=index_params
)
图像数据集提取特征向量后,持久化到 Milvus 中:
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
from pymilvus import MilvusClient
from tqdm import tqdm
import torch
import os

# 生成特征向量
def gen_image_features(processor, model, device, image):
    with torch.no_grad():
      inputs = processor(images=image, return_tensors="pt").to(device)
      outputs = model(**inputs)
      image_features = outputs.last_hidden_state
      image_features = image_features.mean(dim=1)
      return image_features

def main():
    # 创建Milvus客户端
    client = MilvusClient("http://192.168.0.5:19530")
    # 加载模型
    model_dir = "facebook/dinov2-base"
    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
    model.to(device)
    # 读取数据集
    dataset_path = "./img"
    for image_name in tqdm(os.listdir(dataset_path)):
      image_path = os.path.join(dataset_path, image_name)
      image = Image.open(image_path)
      # 提取特征向量
      features = gen_image_features(processor, model, device, image)
      # 存吃至 milvus
      client.insert(
            collection_name="dinov2_collection",
            data={
                "vector": features,
                "image_name": image_name
            }
      )

if __name__ == '__main__':
    main()
运行后,在 Milvus insight 工具中,可以看到存储的内容:
https://i-blog.csdnimg.cn/direct/6bca8876316944b98cb1083df6dd1b90.png
三、图图相似度检索实现 - 图像特征检索

检索就是拿着当前图像的特征,去向量数据库中检索相似的特征信息,实现过程如下:
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
from pymilvus import MilvusClient
import matplotlib.pyplot as plt
import torch
import os


# 生成特征向量
def gen_image_features(processor, model, device, image):
    with torch.no_grad():
      inputs = processor(images=image, return_tensors="pt").to(device)
      outputs = model(**inputs)
      image_features = outputs.last_hidden_state
      image_features = image_features.mean(dim=1)
      return image_features.tolist()


def main():
    # 创建Milvus客户端
    client = MilvusClient("http://192.168.0.5:19530")
    # 加载模型
    model_dir = "facebook/dinov2-base"
    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
    model.to(device)
    # 检索图像, 采用不在Milvus数据集中的图像
    image = Image.open("./img2/dog.jpeg")
    # 提取特征向量
    features = gen_image_features(processor, model, device, image)
    # 特征召回
    results = client.search(
      collection_name="dinov2_collection",
      data=,
      limit=2,
      output_fields=["image_name"],
      search_params={
            "metric_type": "L2",
            "params": {}
      }
    )
    plt.figure()
    plt.axis('off')
    for i, res in enumerate(results):
      image_name = res["entity"]["image_name"]
      image_path = os.path.join("./img", image_name)
      image = Image.open(image_path)
      plt.subplot(1, 2, (i+1))
      plt.imshow(image)
    plt.show()

if __name__ == '__main__':
    main()

测试输入检索图像:
https://i-blog.csdnimg.cn/direct/d6e3f6d4f6f94546a1e663f374d011cd.png#pic_center
召回图像:
https://i-blog.csdnimg.cn/direct/3ddc569f7e9e4e9089ad47e9c0c62794.png
测试输入检索图像:
https://i-blog.csdnimg.cn/direct/8a2ca98cdc4446599603057fccc38f70.png#pic_center
召回图像:
https://i-blog.csdnimg.cn/direct/552fe00ef22e4936acd47a30272e30c3.png

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: 基于 DINOv2 模型实现图搜图相似度检索任务