从零开始搭建CLIP模型实现基于文本的图像检索

鼠扑  论坛元老 | 2025-4-21 03:14:28 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 2015|帖子 2015|积分 6045

CLIP原理简介

   论文链接,源码链接
  CLIP模型由OpenAI在2021年提出,利用双Decoder(Dual Encoder)的架构来学习图像和文本之间的对应关系,是多模态大模型的开创之作,为后续许多高效的多模态模型的提出打下基础。CLIP是一个预练习模型(Pre-trained Model),在学习到图像–文本特性之间的关联后可以迁移到各种下游任务中,如图像分类,文本引导图像分割和目标检测,图像文本检索等。由于模型学习到的是文本语义和图像语义之间的关联,使得其zero-shot能力非常强盛,根据论文中的描述,CLIP在许多数据集上zero-shot的结果乃至超越了许多练习好的模型的效果。CLIP的练习范式如下:

CLIP的布局非常简单,数据集包含大量的图像文本对,图像经过图像编码器得到图像特性,文本经过文本编码器得到文本特性,将图像特性和文本特性按照数据集中的对应关系进行配对,不配对的特性给予惩罚,从上图中可以看出,我们希望矩阵中蓝色的值趋近于1,别的值趋近于0,采用对比学习的方式对模型进行练习,算法的伪代码如下:

从损失函数中可以看出,分别对特性对比矩阵的行和列进行交叉熵损失函数计算,并取平均得到终极的loss。图像编码器一样寻常有两种选择:ResNet和ViT;文本编码器采用Transformer Encoder,均是各自范畴中优秀的特性提取网络。
CLIP的推理范式如下:

在推理阶段,图像编码器中输入图像获取图像特性,文本编码器中输入文本获取文本特性,将图像特性向量和文本特性向量的转置相乘得到每张图像对每个文本的特性相似度,相似度最高的文本即描述了该图像中物体所属的种别。
代码实现

   Flickr8k数据集下载,提取码:fbfz
DistilBert模型文件下载
  我的运行环境:
CUDA 11.8
pytorch 2.2.2
transformers 4.44.0 # 用于从HuggingFace上加载预练习模型

数据集预览:

   图片示例  

   文本示例  
由于作者的显卡算力有限,选取Flickr8k数据集进行模型练习,其中包含8k个图像文本对,其中一张图像对应5条文本。图像编码器采用ResNet50,直接从timm库中导入;文本编码器采用DistilBert,即轻量化的Bert模型,从HuggingFace上下载。闲话少说,小二,上菜!
  1. ### 模型参数配置 ###
  2. import argparse
  3. from dataclasses import dataclass
  4. parser = argparse.ArgumentParser(description="CLIP from zero")
  5. parser.add_argument("--image_dir", default="user/Flickr8k/Images", help='path to image folder')  # 存放图像的文件路径
  6. parser.add_argument("--caption_dir", default="user/Flickr8k", help='path to caption folder')  # 存放文本的文件路径
  7. parser.add_argument("--weight_dir", default='user/checkpoints', help='path to save output weight')  # 存放训练权重的文件路径
  8. args = parser.parse_args()
  9. @dataclass
  10. class CLIPConfig:
  11.     image_path: str = args.image_dir  # 图像存放路径
  12.     image_size: int = 224  # resize后的图像尺寸,便于构建Dataloader
  13.     caption_path: str = args.caption_dir  # 文本存放路径
  14.     batch_size: int = 8  # 一个批次中的数据数量
  15.     epochs: int = 3  # 训练世代
  16.     image_encoder_model: str = "resnet50"  # 图像编码器的名称
  17.     image_embedding_dim: int = 2048  # 图像特征的维度
  18.     text_encoder_model: str = "distilbert-base-uncased"  # 文本编码器的名称
  19.     text_embedding_dim: int = 768  # 文本特征的维度
  20.     text_tokenizer: str = text_encoder_model  # 文本分词器模型的名称
  21.     max_length: int = 200  # 文本编码器可输入的最长文本长度
  22.     pretrained: bool = False  # 是否加载预训练好的编码器
  23.     trainable: bool = True  # 在训练过程中是否更新编码器的参数
  24.     temperature: float = 1.0  # 计算loss时的正则化系数
  25.     proj_dim: int = 256  # 图像特征和文本特征统一后的维度
  26.     dropout_rate: float = 0.1  # dropout系数,避免过拟合
  27. ### 载入数据集并初始化 ###
  28. import torch
  29. from torch.utils.data import Dataset, DataLoader
  30. from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
  31. import albumentations as A
  32. import pandas as pd
  33. import cv2
  34. class CLIPDataset(Dataset):
  35.     def __init__(self, config, image_path, caption_path, transforms=True):
  36.         """
  37.         图片文件名和标题的长度必须相同
  38.         如果一个图片对应多个标题,该图片文件名需要重复多次
  39.         """
  40.         self.image_path = image_path  # 图像路径
  41.         self.caption_path = caption_path  # 文本路径
  42.         self.dataframe = pd.read_csv(f"{self.caption_path}/captions.csv")  # 读取文本
  43.         self.tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)  # 载入分词器
  44.         self.image_filenames = self.dataframe["image"].values  # 获取图像文件名
  45.         self.captions = list(self.dataframe["caption"].values)   # 获取图像对应的描述文本
  46.         self.encoded_captions = self.tokenizer(self.captions,
  47.                                                padding=True,
  48.                                                truncation=True,
  49.                                                max_length=config.max_length)  # 文本分词
  50.         self.transforms = transforms  # 对输入图像进行预处理
  51.     def __getitem__(self, idx):  # 获取数据集中第idx个数据,其中包含图片名称和对应的标题(可能不止一个)
  52.         item = {
  53.             key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()
  54.         }
  55.         image = cv2.imread(f"{self.image_path}/{self.image_filenames[idx]}")  # 获取原始图像
  56.         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  57.         if self.transforms:
  58.             image = self.get_transforms(mode="train")(image=image)["image"]  # 对图像进行预处理
  59.         item["image"] = torch.tensor(image).permute(2, 0, 1).float()  # 将图片转换为tensor格式,并调整为RGB顺序
  60.         item["caption"] = self.captions[idx]  # 获取标题
  61.         return item
  62.     def __len__(self):
  63.         return len(self.captions)  # 获取文本长度
  64.     def get_transforms(self, mode="train"):
  65.         if mode == "train":
  66.             return A.Compose(
  67.                 [
  68.                     A.Resize(config.image_size, config.image_size, always_apply=True),  # 对图像进行resize
  69.                     A.Normalize(max_pixel_value=255.0, always_apply=True)  # 对像素值进行归一化
  70.                 ]
  71.             )
  72. ### 图像编码器 ###
  73. import torch.nn as nn
  74. import timm
  75. class ImageEncoder(nn.Module):
  76.     """
  77.     图像编码器,采用ResNet50
  78.     """
  79.     def __init__(self, config):
  80.         super().__init__()
  81.         self.model = timm.create_model(config.image_encoder_model,
  82.                                        pretrained=config.pretrained,
  83.                                        num_classes=0, global_pool="avg")  # 创建ResNet50
  84.         for p in self.model.parameters():
  85.             p.requires_grad = config.trainable  # 设置参数可训练
  86.     def forward(self, x):
  87.         image_encoded = self.model(x)  # 获得图像特征编码,形状为[batch_size, image_embedding_dim]
  88.         return image_encoded
  89. ### 文本编码器 ###
  90. class TextEncoder(nn.Module):
  91.     """
  92.     文本编码器,采用DistilBERT
  93.     """
  94.     def __init__(self, config):
  95.         super().__init__()
  96.         if config.pretrained:
  97.             self.model = DistilBertModel.from_pretrained(config.text_encoder_model)  # 导入下载好的模型文件
  98.         else:
  99.             self.model = DistilBertModel(DistilBertConfig())
  100.         for p in self.model.parameters():
  101.             p.requires_grad = config.trainable  # 设置参数可训练
  102.         self.target_token_idx = 0
  103.    
  104.     # 提取出和图像对应的文本特征向量
  105.     def forward(self, input_ids, attention_mask):
  106.         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
  107.         text_encoded = output.last_hidden_state[:, self.target_token_idx, :]  # [batch_size, text_embedding_dim]
  108.         return text_encoded
  109. ### 投影层 (MLP) ###
  110. class ProjectionHead(nn.Module):
  111.     """
  112.     将图像编码和文本编码映射到相同维度
  113.     """
  114.     def __init__(self, config, input_embedding_dim):
  115.         super().__init__()
  116.         self.proj = nn.Linear(input_embedding_dim, config.proj_dim)
  117.         self.act_fn = nn.GELU()
  118.         self.fc = nn.Linear(config.proj_dim, config.proj_dim)
  119.         self.dropout = nn.Dropout(config.dropout_rate)
  120.         self.layer_norm = nn.LayerNorm(config.proj_dim)
  121.     def forward(self, x):
  122.         x_proj = self.proj(x)
  123.         x = self.act_fn(x_proj)
  124.         x = self.fc(x)
  125.         x = self.dropout(x)
  126.         x = x + x_proj
  127.         x = self.layer_norm(x)
  128.         return x
  129. ### 定义损失函数 ###
  130. def cross_entropy(logits, labels, reduction='none'):
  131.     log_softmax = nn.LogSoftmax(dim=-1)
  132.     loss = (-labels * log_softmax(logits)).sum(dim=1)
  133.     if reduction == 'mean':
  134.         return loss.mean()
  135.     else:
  136.         return loss.sum()
  137. ### 模型主体 ###
  138. import torch.nn.functional as F
  139. class CLIP(nn.Module):
  140.     def __init__(self, config):
  141.         super().__init__()
  142.         self.image_encoder = ImageEncoder(config)  # 实例化图像编码器
  143.         self.text_encoder = TextEncoder(config)  # 实例化文本编码器
  144.         self.image_proj = ProjectionHead(config, config.image_embedding_dim)  # 图像特征投影
  145.         self.text_proj = ProjectionHead(config, config.text_embedding_dim)  # 文本特征投影
  146.         self.temperature = config.temperature
  147.     def forward(self, batch):
  148.         image_features = self.image_encoder(batch["image"])  # 图像编码
  149.         
  150.         # 文本编码,tokenizer处理后的文本序列自带input_ids和attention_mask
  151.         text_features = self.text_encoder(batch["input_ids"], batch["attention_mask"])
  152.         image_embeddings = self.image_proj(image_features)  # 图像特征投影
  153.         text_embeddings = self.text_proj(text_features)  # 文本特征投影
  154.         logits = (text_embeddings @ image_embeddings.T) / self.temperature  # tensor形状为[batch_size, batch_size]
  155.         images_similarity = image_embeddings @ image_embeddings.T  # tensor形状为[batch_size, batch_size]
  156.         text_similarity = text_embeddings @ text_embeddings.T  # tensor形状为[batch_size, batch_size]
  157.         # 软标签,不配对的位置设置为较小的数,而非0
  158.         labels = F.softmax((images_similarity + text_similarity) / 2 * self.temperature, dim=-1)  
  159.         
  160.         loss_T = cross_entropy(logits, labels)  # 计算文本损失
  161.         loss_I = cross_entropy(logits.T, labels.T)  # 计算图像损失
  162.         total_loss = (loss_T + loss_I) / 2  # 对比学习平均损失
  163.         return total_loss, logits
  164. ### 训练函数 ###
  165. def train(model, optimizer, scheduler, train_loader, device):
  166.     model.train()  # 模型设置为训练模式
  167.     train_loss = 0
  168.     train_loader = tqdm(train_loader, total=len(train_loader))  # 显示训练进度条
  169.     cnt = 0
  170.     for batch in train_loader:
  171.         # print(batch.keys())
  172.         cnt += 1
  173.         batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}  # 将dataloader中一个batch的数据转换为字典形式
  174.         loss, _ = model(batch)
  175.         optimizer.zero_grad()
  176.         loss.backward()
  177.         optimizer.step()
  178.         scheduler.step(metrics=loss.item())  # 根据上次训练的损失更新学习率
  179.         train_loss += loss.item()
  180.         # 训练100个batch显示一次loss
  181.         if cnt % 100 == 0:
  182.             print(f' ==> Epoch: {epoch + 1}, Batch: {cnt}, Loss: {loss.item():.4f}')
  183.     return train_loss / len(train_loader)  # 平均训练损失
  184. ### 测试函数 ###
  185. def eval(model, val_loader, device):
  186.     model.eval()  # 模型设置为测试模式
  187.     val_loss = 0
  188.     val_loader = tqdm(val_loader, total=len(val_loader))
  189.     with torch.no_grad():
  190.         for batch in val_loader:
  191.             batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
  192.             loss, _ = model(batch)
  193.             val_loss += loss.item()
  194.     return val_loss / len(val_loader)  # 平均测试损失
  195. if __name__ == '__main__':
  196.     config = CLIPConfig()  # 实例化配置信息
  197.     model = CLIP(config)  # 实例化CLIP模型
  198.     device = "cuda" if torch.cuda.is_available() else "cpu"
  199.     model = model.to(device)
  200.     # 查看模型的总参数量
  201.     total_params = sum(p.numel() for p in model.parameters())
  202.     print(f"Total parameters: {total_params / 1e6} M")
  203.     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
  204.     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2, factor=0.5)
  205.     dataset = CLIPDataset(config, args.image_dir, args.caption_dir)  # 读取并预处理数据
  206.     train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])  # 80%为训练数据,20%为测试数据
  207.     dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=False)
  208.     train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
  209.     val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
  210.     # 开始训练
  211.     best_loss = float("inf")
  212.     for epoch in range(config.epochs):
  213.         print(f"Epoch: {epoch + 1}")
  214.         train_loss_avg = train(model, optimizer, scheduler, train_loader, device)
  215.         val_loss_avg = eval(model, val_loader, device)
  216.         if val_loss_avg < best_loss:
  217.             best_loss = val_loss_avg
  218.             torch.save(model.state_dict(), f'{args.weight_dir}' + f'/CLIP_{epoch}.pth')
  219.             print("Best model saved!")
  220.     # 图像文本检索推理并可视化
  221.     # dataframe = pd.read_csv(f"{config.caption_path}/captions.csv")
  222.     # tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
  223.     # model.load_state_dict(torch.load(f'{args.weight_dir}' + f'/CLIP_1.pth', map_location=device))
  224.     # model.eval()
  225.     #
  226.     # image_embeddings = []
  227.     # with torch.no_grad():
  228.     #     for batch in tqdm(dataloader):
  229.     #         image_features = model.image_encoder(batch["image"].to(device))  # 获取图像特征
  230.     #         cur_image_embeddings = model.image_proj(image_features)  # [batch_size, proj_dim]  # 图像特征投影
  231.     #         image_embeddings.append(cur_image_embeddings)  # 将一个batch的图像特征保存
  232.     #
  233.     # image_embeddings = torch.cat(image_embeddings, dim=0)  # [image_number, proj_dim]
  234.     # input_query = "two dogs sitting on the grass"  # 输入文本
  235.     # image_filenames = dataframe["image"].values  # 待检索的图片
  236.     #
  237.     # encoded_query = tokenizer([input_query])  # 对输入文本进行分词
  238.     # batch = {key: torch.tensor(values).to(device) for key, values in encoded_query.items()}
  239.     #
  240.     # with torch.no_grad():
  241.     #     text_features = model.text_encoder(batch["input_ids"], batch["attention_mask"])  # 获取文本特征
  242.     #     text_embeddings = model.text_proj(text_features)  # 文本特征投影,与图像特征维度一致
  243.     #
  244.     # image_embeddings_n = F.normalize(image_embeddings, dim=-1)  # [image_number, proj_dim]
  245.     # text_embeddings_n = F.normalize(text_embeddings, dim=-1)  # [1, proj_dim]
  246.     # dot_similarity = text_embeddings_n @ image_embeddings_n.T  # 输入文本的特征和数据集中每张图像特征之间的相似度
  247.     #
  248.     # values, indices = torch.topk(dot_similarity.squeeze(0), k=45)  # 获取前45个相似度最高的图像
  249.     # matches = [image_filenames[idx] for idx in indices[::5]]  # 获取对应的图像文件名(9张图像)
  250.     #
  251.     # f, axes = plt.subplots(3, 3, figsize=(10, 10))
  252.     # f.suptitle(f"Retrieving text: {input_query}")  # 设置主标题
  253.     # for match, ax in zip(matches, axes.flatten()):  # 显示检索出的图像
  254.     #     image = cv2.imread(f"{args.image_dir}/{match}")
  255.     #     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  256.     #     ax.imshow(image)
  257.     #     ax.axis("off")
  258.     #
  259.     # plt.show()
复制代码
理想结果:

参考链接

https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2/

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

鼠扑

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表