【LLM论文日更】| 俄罗斯套娃嵌入模型

饭宝  论坛元老 | 2024-9-23 10:01:10 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1041|帖子 1041|积分 3123

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x

  • 论文:https://proceedings.neurips.cc/paper_files/paper/2022/file/c32319f4868da7613d78af9993100e42-Paper-Conference.pdf
  • 代码:GitHub - RAIVNLab/MRL: Code repository for the paper - "Matryoshka Representation Learning"
  • 机构:McGill University, Mila ServiceNow Research ,Facebook CIFAR AI Chair
  • 领域:embedding model
  • 发表:NeurIPS 2022
研究背景


  • 研究标题:这篇文章要办理的标题是如何设计一种灵活的表示学习方法,使其可以大概适应多个卑鄙使命,而且可以大概根据使命的盘算资源需求进行调整。
  • 研究难点:该标题的研究难点包括:现有固定容量的表示在学习新使命时大概过度或不敷;如何在保持准确性的前提下,明显减少表示的大小和盘算本钱;如何扩展表示学习方法以适应不同模态(如视觉、语言)和数据规模(如网页规模)。
  • 干系工作:该标题的研究干系工作包括大规模数据集上的通用表示学习(如ImageNet和JFT),对比学习(如Contrastive Learning),以及天然语言处理中的预训练模型(如BERT)。这些工作通常依赖于独立的低维模型、子网络优化或后处理压缩来实现表示的灵活性,但这些方法在训练/维护开销、多次前向传播、存储和内存本钱等方面存在不敷。
研究方法

这篇论文提出了Matryoshka Representation Learning(MRL)用于办理表示学习中的灵活性标题。具体来说,

  • 多粒度表示:MRL通过显式优化嵌套的O(log(d))个低维向量,在高维向量中捕获多粒度信息。每个嵌入的前几个维度是一个信息丰富的低维向量,随着维度的增加,表示逐渐变得粗糙。

优化目的:MRL的目的是学习一个d维表示向量z∈Rd,使得每个嵌套维度m∈M都能独立地作为数据点x的可迁徙通用表示。优化目的是使用标准履历风险最小化方法,通过单独的线性分类器来优化每个嵌套维度的多类分类丧失。

其中,L是多类softmax交叉熵丧失函数,cm​是相对重要性权重。
3. 高效实现:为了提高效率,MRL采用了权重绑定技术,即全部线性分类器的权重类似,从而减少内存本钱。这种变体称为Efficient Matryoshka Representation Learning(MRL-E)。
实现代码为:
 
  1. class MRL_Linear_Layer(nn.Module):
  2.         def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs):
  3.                 super(MRL_Linear_Layer, self).__init__()
  4.                 self.nesting_list = nesting_list
  5.                 self.num_classes = num_classes # Number of classes for classification
  6.                 self.efficient = efficient
  7.                 if self.efficient:
  8.                         setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))               
  9.                 else:       
  10.                         for i, num_feat in enumerate(self.nesting_list):
  11.                                 setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))       
  12.         def reset_parameters(self):
  13.                 if self.efficient:
  14.                         self.nesting_classifier_0.reset_parameters()
  15.                 else:
  16.                         for i in range(len(self.nesting_list)):
  17.                                 getattr(self, f"nesting_classifier_{i}").reset_parameters()
  18.         def forward(self, x):
  19.                 nesting_logits = ()
  20.                 for i, num_feat in enumerate(self.nesting_list):
  21.                         if self.efficient:
  22.                                 if self.nesting_classifier_0.bias is None:
  23.                                         nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
  24.                                 else:
  25.                                         nesting_logits += (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
  26.                         else:
  27.                                 nesting_logits +=  (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)
  28.                 return nesting_logits
复制代码
借用一张图,很直观:



实验设计


  • 数据集:实验使用了多个大规模数据集,包括ImageNet-1K、JFT-300M和ALIGN数据集。对于视觉使命,使用了ResNet50和ViT-B/16模型;对于视觉+语言使命,使用了ALIGN模型;对于语言使命,使用了BERT模型。
  • 实验设置:实验中,MRL和MRL-E模型与独立训练的低维表示(FF)、降维(SVD)、子网络方法(slimmable networks)和随机选择的高容量特征进行比较。实验评估了线性分类/探测(LP)和1-近来邻(1-NN)准确性。
  • 参数配置:实验中使用的超参数与独立训练的基线模型类似。例如,ResNet50输出2048维表示,ViT-B/16和BERT-Base输出768维嵌入。
本文将MRL/MRL-E模型与单独训练的低维表征(FF),SVD分解,子网络[2]方法进行了比较
首先是分类使命。对于在ImageNet上训练的模型,线性分类准确率基本和FF保持同等,1-NN准确率以致在低维时高于FF。


对于大规模数据集上训练的模型也取得了很好的精度与速度间的平衡


对于适应性分类,期望的表征大小相比FF减小了14倍。


图像检索的结果也超越了baseline,最高凌驾了FF 3%。适应性图像检索也到达了效率和精度的权衡,16维度做粗排,2048维度做精排的准确率已经和直接使用2048维度做排序的精度还高,但盘算量大幅减小。值得一提的是本文提出了一个漏斗检索方法,纵然用逐渐增大的维度16-32-64-128-256-2048 对前200-100-50-25-10个样本的渐渐重排,这种方法可以省去调参,应用比较方便。

不敷与反思


  • 嵌套丧失权重的优化:未来的工作可以探索自适应丧失平衡方法,以实现更优的准确性-效率权衡。
  • 不同保真度的丧失函数:可以考虑使用针对不同保真度的丧失函数,以办理特定方面的自适应部署标题,例如高召回率的8维表示和鲁棒的2048维表示。
  • 搜索数据结构的集成:可以在MRL上学习一个可微分的k-d树,以实现数据集和表示感知的检索。
  • 多目的MRL的联合优化:结合端到端可学习的搜索数据结构,进行数据驱动的自适应大规模检索,实用于Web规模的搜索应用。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

饭宝

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表