视频生成模子Sora的全面解析:从AI绘画、ViT到ViViT、DiT、VDT、NaViT、Vid ...

打印 上一主题 下一主题

主题 806|帖子 806|积分 2418


视频生成模子Sora的全面解析:从AI绘画、ViT到ViViT、DiT、VDT、NaViT、VideoPoet

真没想到,举例视频生成上一轮的集中爆发才已往三个月,没想OpenAI一脱手,该范畴又直接变天了自打2.16日OpenAI发布sora以来,不光把同时段Google发布的Gemmi Pro 1.5干没了声音,而且网上各个渠道,大量新闻媒体、自媒体(含公号、微博、博客、视频)做了大量的解读,也引发了圈内外的大量关注,很多人因此认为,视频生成范畴自此进入了大规模应用前夕,比如NLP范畴中GPT3的发布一
​编辑 v_JULY_v  ·  2024-02-18 22:49:27 发布
前言

真没想到,距离视频生成上一轮的集中爆发(详见《Sora之前的视频生成发展史:从Gen2、Emu Video到PixelDance、SVD、Pika 1.0》)才已往三个月,没想OpenAI一脱手,该范畴又直接变天了

  • 自打2.16日OpenAI发布sora以来(其开发团队包罗DALLE 3的4作Tim Brooks、DiT一作Bill Peebles、三代DALLE的焦点作者之一Aditya Ramesh等13人),不光把同时段Google发布的Gemini 1.5干没了声音,而且网上各个渠道,大量新闻媒体、自媒体(含公号、微博、博客、视频)做了大量的解读,也引发了圈内外的大量关注
    很多人因此认为,视频生成范畴自此进入了大规模应用前夕,比如NLP范畴中GPT3的发布
  • 一开始,我还自以为视频生成这玩意对于有场景的人,是庞大利好,比如在影视行业的
    对于没场景的人,只能当热闹看看,而且我司大模子项目开发团队去年年底还考虑过是否做视频生成的应用,但当时想了很久,没找加入景,做别的应用去了
可当我接连扒出sora相关的10多篇论文之后,觉得sora和此前发布的视频生成模子有了质的飞跃(不但是一个60s),而是再次印证了鼎力大举出奇迹,大模子似乎可以在力大砖飞的环境下开始明白物理世界了,使得我司大模子项目组也愿意重新考虑开发视频生成的相关应用
本文主要分为三个部门(初步明白只看第一部门即可,深入明白看第二部门,更多细节则看第三部门)


  • 第一部门,侧重sora的焦点技能解读
    方便各人把握重点,且会比一切新闻稿都更正确,此外
      如果之前没有了解过DDPM、ViT的,发起先阅读下此文《从VAE、扩散模子DDPM、DETR到ViT、Swin transformer》
      如果之前没有了解过图像生成的,发起先阅读下此文《从CLIP、BLIP到DALLE、DALLE 2、DALLE 3、Stable Diffusion》
    固然,如果个别朋侪实在不想点开看上面的两篇文章,我也尽可能在本文中把相关重点交代清楚
  • 第二部门,侧重sora相近技能的发展演变
    把sora涉及到的关键技能在本文中全部全面、深入、过细的阐述清楚,究竟如果人云亦云就不用我来写了
    且看完这部门你会发现,从来没有任何一个火爆环球的产品是一蹴而就的,且根本都是各种创新技能的集大成者(Google很多工作把transformer等各路技能发扬光大,但OpenAI则把各路技能 整合到极致了..)
  • 第三部门,根据sora的32个reference以窥探其背后的更多细节
    由于sora实在是太火了,网上各种解读非常多,有的很专业,有的看上去一本正经 实则是颠三倒四(即便他的title看起来有一定的程度),为方便各人辨别什么样的解读是不对的,特把一些更深入的细节也介绍下
总之,看本文之前,如果你人云亦云的来一句:sora就是DiT架构,我表示明白。但看完全文后你会发现


  • 如果只答应用10个字定义sora的模子结构,则可以是:潜伏扩散架构下的Video Transformer
  • 如果答应25个字以内,则是:带文本条件融合且时空留意力并行盘算的Video Diffusion Transformer

保举内容

第一部门 OpenAI Sora的关键技能点

1.1 Sora的三大Transformer组件

1.1.1 从前置工作DALLE 2到sora的三大组件

为方便各人更好的明白sora背后的原理,我们先来快速回顾下AI绘画的原理(明白了AI绘画,也就明白了sora一半)
   以DALLE 2为例,如下图所示(以下内容来自此文:从CLIP、BLIP到DALLE、DALLE 2、DALLE 3、Stable Diffusion)
  

  

  • CLIP训练过程:学习笔墨与图片的对应关系
    如上图所示,CLIP的输入是一对对配对好的的图片-文本对(根据对应文本一条狗,去匹配一条狗的图片),这些文本和图片分别通过Text Encoder和Image Encoder输出对应的特性,然后在这些输出的笔墨特性和图片特性上进行对比学习
  • DALL·E2:prior + decoder
    上面的CLIP训练好之后,就将其冻住了,不再加入任何训练和微调,DALL·E2训练时,输入也是文本-图像对,下面就是DALL·E2的两阶段训练:
      阶段一 prior的训练:根据文本特性(即CLIP text encoder编码后得到的文本特性),预测图像特性(CLIP image encoder编码后得到的图片特性)
    换言之,prior模子的输入就是上面CLIP编码的文本特性,然后利用文本特性预测图片特性(阐明白点,即图中右侧下半部门预测的图片特性的ground truth,就是图中右侧上半部门颠末CLIP编码的图片特性),就完成了prior的训练
    推理时,文本还是通过CLIP text encoder得到文本特性,然后根据训练好的prior得到类似CLIP生成的图片特性,此时图片特性应该训练的非常好,不仅可以用来生成图像,而且和文本接洽的非常紧(包含丰富的语义信息)

      阶段二 decoder生成图:常规的扩散模子解码器,解码生成图像
    这里的decoder就是升级版的GLIDE(GLIDE基于扩散模子),以是说DALL·E2 = CLIP + GLIDE
  以是对于DALLE 2来说,正因为颠末了大量上面这种训练,以是便可以根据人类给定的prompt画出人类预期的画作,说白了,可以根据text预测画作长什么样
终极,sora由三大Transformer组件组成(如果你还不了解transformer或留意力机制,请读此文):Visual Encoder(即Video transformer,类似下文将介绍的ViViT)、Diffusion TransformerTransformer Decoder,具体而言


  • 训练中,给定一个原始视频

      Visual Encoder将视频压缩到较低维的潜伏空间(潜伏空间这个概念在stable diffusion中用的可谓出神入化了,详见此文的第三部门)
      然后把视频分解为在时间和空间上压缩的潜伏表示(不重叠的3D patches),即所谓的一系列时空Patches
      再将这些patches拉平成一个token序列,这个token序列实在就是原始视频的表征:visual token序列
  • Sora 在这个压缩的潜伏空间中接受训练,还是类似扩散模子那一套,先加噪、再去噪
    这里,有两点必须留意的是
      1 扩散过程中所用的噪声估计器U-net被替换成了transformer结构的DiT(加之视觉元素转换成token之后,transformer擅长长距离建模,下文详述DiT)
      2 视频中这一系列帧在上个过程中是同时被编码的,去噪也是一系列帧并行去噪的(每一帧逐步去噪、多帧并行去噪)
    此外,去噪过程中,可以参加去噪的条件(即text condition),这个去噪条件一开始可以是原始视频
    的描述,后续还可以是基于原视频进行二次创作的prompt
    比如可以将visual tokens视为query,将text tokens作为key和value,然后类似SD那样做cross attention
  • OpenAI 还训练了相应的Transformer解码器模子,将生成的潜伏表示映射回像素空间,从而生成视频

   你会发现,上述整个过程,实在和SD的原理是有较大的相似性(SD原理见此文《从CLIP、BLIP到DALLE、DALLE 2、DALLE 3、Stable Diffusion》的3.2节),固然,不同之处也有很多,比如视频需要一次性还原多帧、图像只需要还原一帧
  

  网上也有不少人画出了sora的架构图,比如来自魔搭社区的
  

  1.1.2 如何明白所谓的时空编码(含其利益)

起首,一个视频无非就是沿着时间轴分布的图像序列而已

但此中有个问题是,因为像素的关系,一张图像有着比力大的维度(比如250 x 250),即一张图片上可能有着5万多个元素,如果根据上一张图片的5万多元素去逐一交互下一张图片的5万多个元素,未免工程过于浩大(而且,即便是同一张图片上的5万多个像素点之间两两做self-attention,你都会发现盘算复杂度超级高)

  • 故为低落处理的复杂度,可以类似ViT把一张图像分别为九宫格(如下图的左下角),如此,处理9个图像块总比一次性处理250 x 250个像素维度 要好不少吧(ViT的出现直接挑衅了此前CNN在视觉范畴长达近10年的绝对统治职位,其原理细节详见本文开头提到的此文第4部门)

  • 当我们明白了一张静态图像的patch表示之后(不管是九宫格,还是16 x 9个格),再来明白所谓的时空Patches就简朴多了,无非就是在纵向上加上时间的维度,比如t1 t2 t3 t4 t5 t6
    而一个时空patch可能跨3个时间维度,固然,也可能跨5个时间维度

    如此,同时间段内不同位置的立方块可以做横向留意力交互——空间编码
    不同时间段内相同位置的立方块则可以做纵向留意力交互——时间编码
    (如果依然还没有特别明白,不要紧,可以再看下下文第二部门中对ViViT的介绍)

可能有同砚问,这么做有什么利益呢?利益太多了


  • 一方面,时空建模之下,不仅提高单帧的流畅、更提高帧与帧之间的流畅,究竟有Transformer的留意力机制,那无论哪一帧图像,各个像素块都不再是孤立的存在,都与周围的元素精密接洽
  • 二方面,可以兼容所有的数据素材:一个静态图像不外是时间=0的一系列时空patch,不同的像素尺寸、不同的时间长短,都可以通过组合一系列 “时空patch” 得到
总之,基于 patches 的表示,使 Sora 能够对不同分辨率、连续时间和长宽比的视频和图像进行训练。在推理时,也可以可以通过在得当大小的网格中分列随机初始化的 patches 来控制生成视频的大小
   DiT 作者之一 Saining Xie 在推文中提到:Sora“可能还利用了谷歌的 Patch n’ Pack (NaViT) 论文结果,使其能够顺应可变的分辨率/连续时间/长宽比”
  
  固然,ViT自己也能够处理任意分辨率(不同分辨率相当于不同长度的图片块序列),但NaViT提供了一种高效训练的方法,关于NaViT的更多细节详见下文的介绍
  而已往的图像和视频生成方法通常需要调解大小、进行裁剪或者是将视频剪切到标准尺寸,比方 4 秒的视频分辨率为 256x256。相反,该研究发如今原始大小的数据上进行训练,终极提供以下利益:

  • 起首是采样的灵活性:Sora 可以采样宽屏视频 1920x1080p,垂直视频 1920x1080p 以及两者之间的视频。这使 Sora 可以直接以其天然纵横比为不同装备创建内容。Sora 还答应在生成全分辨率的内容之前,以较小的尺寸快速创建内容原型 —— 所有内容都利用相同的模子

  • 其次利用视频的原始长宽比进行训练可以提拔内容组成和帧的质量
    其他模子一般将所有训练视频裁剪成正方形,而颠末正方形裁剪训练的模子生成的视频(如下图左侧),此中的视频主题只是部门可见;相比之下,Sora 生成的视频具有改进的帧内容(如下图右侧)

1.1.3 Diffusion Transformer(DiT):扩散过程中以Transformer为骨干网络

sora不是第一个把扩散模子和transformer结合起来用的模子,但是第一个取得巨大乐成的,为何说它是结合体呢

  • 一方面,它类似扩散模子那一套流程,给定输入噪声patches(以及文本提示等调节信息),训练出的模子来预测原始的不带噪声的patches「Sora is a diffusion model, given input noisy patches (and conditioning information like text prompts), it’s trained to predict the original “clean” patches
    类似把视频中的一帧帧画面打上各种马赛克,然后训练一个模子,让它学会去除各种马赛克,且一开始各种失败不要紧,反正有原画面作为ground truth,不断缩小与原画面之间的差异即可
    而当把一帧帧图片打上全部马赛克之后,可以根据”文本-视频数据集”中对视频的描述/prompt(该描述/prompt不仅仅只是通过CLIP去与视频对齐,还颠末类似DALLE 3所用的重字幕技能加强 + GPT4对字幕的进一步丰富,下节详述),而有条件的去噪
  • 二方面,它把DPPM中的噪声估计器所用的卷积架构U-Net换成了Transformer架构

总之,总的来说,Sora是一个在不同时长、分辨率和宽高比的视频及图像上训练而成的扩散模子,同时采用了Transformer架构,如sora官博所说,Sora is a diffusion transformer,简称DiT
关于DiT的更多细节详见下文第二部门介绍的DiT
1.2 基于DALLE 3的重字幕技能:提拔文本-视频数据质量

1.2.1 DALLE 3的重字幕技能:为文本-视频数据集打上字幕且用GPT把字幕详细化

起首,训练文本到视频生成系统需要大量带有相应文本字幕的视频,而通过CLIP技能给视频对齐的文本描述,有时质量较差,故为进一步提高文本-视频数据集的质量,研究团队将 DALL・E 3 中的重字幕(re-captioning)技能应用于视频

  • 具体来说,研究团队起首训练一个高度描述性的字幕生成器模子,然后利用它为训练集中所有视频生成文本字幕
  • 与DALLE 3类似,研究团队还利用 GPT 将用户简短的prompt 转换为较长的详细字幕,然后发送给视频模子(Similar to DALL·E 3, we also leverage GPT to turn short user prompts into longer detailed captions that are sent to the video model),这使得 Sora 能够生成正确遵循详细字幕或详细prompt 的高质量视频
   关于DALLE 3的重字幕技能更具体的细节请见此文2.3节《AI绘画原明白析:从CLIP、BLIP到DALLE、DALLE 2、DALLE 3、Stable Diffusion》
  2.3 DALLE 3:Improving Image Generation with Better Captions
  2.3.1 为提高文本图像配对数据集的质量:基于谷歌的CoCa​微调出图像字幕生成器
  2.3.1.1 什么是谷歌的CoCa
  2.1.1.2 分别通过短caption、长caption微调预训练好的image captioner
  2.1.1.3 为提高合成caption对文生图模子的性能:采用描述详细的长caption,训练的混淆比例高达95%..
  1.2.2 类似VDT或Google的W.A.L.T工作:引入auto regressive进行视频预测或扩展

其次,如之前所述,为了包管视频的一致性,模子层不是通过多个stage方式来进行预测,而是团体预测了整个视频的latent(即去噪时非先去噪几帧,再去掉几帧,而是一次性去掉全部帧的噪声)
但在视频内容的扩展上,比如从一段已有的视频向后拓展出新视频的训练过程中可能引入了auto regressive的task,以资助模子更好的进行视频特性和帧间关系的学习
更多可以参考下文Google的W.A.L.T工作,或下文“2.3.2 VDT的视频预测方案:把视频前几帧作为条件帧自回归预测下一帧”
1.3 对真实物理世界的模拟能力

1.3.1 sora学习了大量关于3D几何的知识

OpenAI 发现,视频模子在颠末大规模训练后,会表现出许多风趣的新能力。这些能力使 Sora 能够模拟物理世界中的人、动物和环境的某些方面。这些特性的出现没有任何明白的三维、物体等归纳弊端 — 它们纯粹是规模征象

  • 三维一致性(下图左侧)
    Sora 可以生成动态摄像机活动的视频。随着摄像机的移动和旋转,人物和场景元素在三维空间中的移动是一致的
    针对这点,sora一作Tim Brooks说道,sora学习了大量关于3D几何的知识,但是我们并没有事先设定这些,它完全是从大量数据中学习到的

    长序列连贯性和目标持久性(上图右侧)
    视频生成系统面对的一个庞大挑衅是在对长视频进行采样时保持时间一致性
    比方,即使人、动物和物体被遮挡或脱离画面,Sora 模子也能保持它们的存在。同样,它还能在单个样本中生成同一角色的多个镜头,并在整个视频中保持其外观
  • 与世界互动(下图左侧)
    Sora 有时可以模拟以简朴方式影响世界状态的动作。比方,画家可以在画布上留下新的笔触,这些笔触会随着时间的推移而连续,而视频中一个人咬一口面包 则面包上会有一个被咬的缺口

    模拟数字世界(上图右侧)
    视频游戏就是一个例子。Sora 可以通过根本策略同时控制 Minecraft 中的玩家,同时高保真地呈现世界及其动态。只需在 Sora 的提示字幕中提及 「Minecraft」,就能零样本激发这些功能
1.3.2 sora真的会模拟真实物理世界了么

对于“sora真的会模拟真实物理世界”这个问题,网上的解读非常多,很多人说sora是通向通用AGI的必经之路、不但是一个视频生成,更是模拟真实物理世界的模拟器,这个事 我个人觉得从技能的客观角度去探讨更符合,那样会让咱们的思维、认知更岑寂,而非人云亦云、终极不知所云
起首,作为“物理世界的模拟器”,需要能够在虚拟环境中重现物理现实,为用户提供一个逼真且不违反「物理规律」的数字世界
比如苹果不能突然在空中漂泊,这不符合牛顿的万有引力定律;比如在光线照射下,物体产生的阴影和高光的分布要符合光影规律等;比如物体之间产生碰撞后会破裂或者弹开
其次,李志飞等人在《为什么说 Sora 是世界的模拟器?》一文中提到,技能上至少有两种方式可以实现这样的模拟器


  • 一种是通过大数据学习出一个AI系统来模拟这个世界,比如说本文讨论的Sora能get到:“树叶在溪流中顺流而下”这句话所对应的物体活动轨迹是什么,更何况sora训练时还有LLM的夹持(别忘了上文1.2.1节中说的:与DALLE 3类似,研究团队还利用 GPT 将用户简短的prompt 转换为较长的详细字幕,然后发送给视频模子)
    比如在大量的文本-视频数据集中,GPT给一个视频写的更丰富的文本描述是:“路面积水反射出大楼的倒影”,而Sora遵循文本能力强,那Sora就能固定或机械的影象住该物理定律,但实在这个物理规则来自于GPT写的Prompt
  • 另外一种是弄懂物理世界各种征象背后的数学原理,并把这些原理手工编码到盘算机步伐里,从而让盘算机步伐“渲染”出物理世界需要的各种人、物、场景、以及他们之间的互动
   虚幻引擎(Unreal Engine,UE)就是这种物理世界的模拟器
  

  • 它内置了光照、碰撞、动画、刚体、材质、音频、光电等各种数学模子。一个开发者只需要提供人、物、场景、交互、剧情等配置,系统就能做出一个交互式的游戏,这种交互式的游戏可以看成是一个交互式的动态视频
  • UE 这类渲染引擎所创造的游戏世界已经能够在某种程度上模拟物理世界,只不外它是通过人工数学建模及渲染而成,而非通过模子从数据中自我学习。而且,它也没有和语言代表的认知模子连接起来,因此本质上缺乏世界常识。而 Sora 代表的AI系统有可能克制这些缺陷和范围
  不同于 UE 这一类渲染引擎,Sora 并没有显式地对物理规律背后的数学公式去“硬编码”,而是通过对互联网上的海量视频数据进行自监督学习,从而能够在给定一段笔墨描述的条件下生成不违反物理世界规律的长视频
与 UE 这一类“硬编码”的物理渲染引擎不同,Sora 视频创作的想象力来自于它端到端的数据驱动,以及跟LLM这类认知模子的无缝结合(比如ChatGPT已经确定了根本的物理常识)
最后值得一提的是,Sora 的训练可能用了 UE 合成的数据,但 Sora 模子自己应该没有调用 UE 的能力


第二部门 Sora相近技能的发展史:ViViT、DiT、VDT、NaViT、MAGVIT v2、W.A.L.T、VideoPoet

留意,和sora相关的技能实在有非常多,但有些技能在本博客之前的文章中写过了(详见本文开头),则本部门不再重复,比如DDPM、ViT、DALLE三代、Stable Diffusion(包罗潜伏空间LDM)等等
2.1 视频Transformer之ViViT:视频元素token化且时空编码(没加扩散过程、没带文本条件融合)

在具体介绍ViViT之前,先说三个在其之前的工作

  • 业界最早是用卷积那一套处理视频,比如时空3D CNN(Learning spatiotemporal features with 3d convolutional networks),由于3D CNN比图像卷积网络需要较多的盘算量,许多架构在空间和时间维度上进行卷积的因式分解和/或利用分组卷积,且近来,还通过在后续层中引入自留意力来增强模子,以更好地捕捉长程依赖性
  • 而Transformer在NLP范畴大获乐成,很快便出现了将Transformer架构应用到图像范畴的ViT(Vision Transformer),ViT将图片按给定大小分为不重叠的patches,再将每个patch线性映射为一个token,随位置编码和cls token(可选)一起输入到Transformer的编码器中(下图来自萝卜社长,如果不认识或忘了ViT的,详见此文的第4部门)

  • 2021年的这两篇论文《Is space-time attention all you need for video understanding?》、《Video transformer network》都是基于transformer做视频明白
而Google于2021年5月提出的ViViT(其对应论文为:ViViT: A Video Vision Transformer)便要尝试在视频中利用ViT模子,且他们充分借鉴了之前3D CNN因式分解等工作,比如考虑到视频作为输入会产生大量的时空token,处理时必须考虑这些长范围token序列的上下文关系,同时要兼顾模子效率问题
故作者团队在空间和时间维度上分别对Transformer编码器各组件进行分解,在ViT模子的底子上提出了三种用于视频分类的纯Transformer模子,如下图所示

区别于常规的二维图像数据,视频数据相当于需在三维空间内进行采样(拓展了一个时间维度),有两种方法来将视频
映射到token序列
(说白了,就是从视频中提取token,而后添加位置编码并对token进行reshape得到终极Transformer的输入
)


  • 第一种,如下图所示,将输入视频分别为token的直接方法是从输入视频剪辑中均匀采样 
     个帧,利用与ViT 相同的方法独立地嵌入每个2D帧(embed each 2D frame independently using the same method as ViT),并将所有这些token连接在一起

    具体地说,如果从每个帧中提取 
     个非重叠图像块(就像 ViT 一样),那么总共将有 
     个token通过transformer编码器进行传递,这个过程可以被看作是简朴地构建一个大的2D图像,以便按照ViT的方式进行tokenised(这点和本节开头所提到的21年那篇论文space-time attention for video所用的方式一致)
  • 第二种则是把输入的视频分别成多少个tuplet(类似不重叠的带空间-时间维度的立方体)
    每个tuplet会变成一个token(因这个tublelt的维度就是: t * h * w,故token包含了时间、宽、高)
    颠末spatial temperal attention进行空间和时间建模获得有效的视频表征token

2.1.1 spatio-temporal attention

上文说过,Google在ViT模子的底子上提出了三种用于视频分类的纯Transformer模子,接下来,介绍下这三种模子
固然,由于论文中把一个没有啥技巧且盘算复杂度高的模子作为模子1:简朴地将从视频中提取的所有时空token,然后每个transformer层都对所有配对进行建模,类似Neimark_Video_Transformer_Network_ICCVW_2021_paper的工作(其证明了VTN可以高效地处理非常长的视频)

以是下述三个模子在论文中被分别称之为模子2、3、4
2.1.2 factorised encoder及其代码实现

第二个模子如下图所示,该模子由两个串联的transformer编码器组成:


  • 第一个模子是空间编码器Spatial Transformer Encoder
    处理来自相同时间索引的token之间的相互作用(相当于处理同一帧画面下的各个元素,时间维度都相同了天然时间层面上没啥要处理的了,只处理空间维度),以产生每个时间索引的潜伏表示,并输出cls_token
  • 第二个transformer编码器是时间编码器Temporal Transformer Encoder
    处理时间步之间的相互作用(相当于处理不同帧,即空间维度相同但时间维度不同)。 因此,它对应于空间和时间信息的“后期融合”
    换言之,将输出的cls_token和帧维度的表征token拼接输入到时间编码器中得到终极的结果
对应的代码如下(为方便各人一目了然,我不仅给每一行代码都加上了注释,且把代码分解成了8块,每一块代码的重点都做了过细阐明)

  • 起首定义ViViT类,且定义相关变量
    1. # 定义ViViT模型类
    2. class ViViT(nn.Module):
    3.     def __init__(self, image_size, patch_size, num_classes, num_frames, dim=192, depth=4, heads=3, pool='cls', in_channels=3, dim_head=64, dropout=0.,
    4.                  emb_dropout=0., scale_dim=4):
    5.         super().__init__()  # 调用父类的构造函数
    6.         
    7.         # 检查pool参数是否有效
    8.         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
    9.         # 确保图像尺寸能被patch尺寸整除
    10.         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    11.         # 计算patch数量
    12.         num_patches = (image_size // patch_size) ** 2
    13.         # 计算每个patch的维度
    14.         patch_dim = in_channels * patch_size ** 2
    15.         # 将图像切分成patch并进行线性变换的模块
    16.         self.to_patch_embedding = nn.Sequential(
    17.             Rearrange('b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
    18.             nn.Linear(patch_dim, dim),
    19.         )
    复制代码
    为方便各人明白,我得表明一下上面中这行的寄义:
    1. Rearrange('b t c (h p1) (w p2) -> b t (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
    复制代码
    且为方便各人和我之前介绍ViT的文章前后连贯起来,故还是用的ViT那篇文章中的例子(此文的第4部门)
    以ViT_base_patch16为例,一张224 x 224的图片先分割成 16 x 16 的 patch ,很显然会因此而存在 
     个 patch
    且图片的长宽由原来的224  x 224 变成:14  x 14(因为224/16 = 14)   
    16*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*1616*16
    16*16
    16*16
    16*16
    ...
    以是对于上面那行意味着可以让批次大小b=1、时间维度t=2、RGB图像的通道数c=3
    原始维度即为:
    (1, 2, 3, 旧的长 = 224 patch_size = 16, 旧的宽 = 224 patch_size = 16),Rearrange之后的维度则变为:
    (12, 新的长14 x 新的宽14 = 196, 16 x 16 x 3 = 768)
  • 初始化位置编码和cls token
    self.pos_embedding 的维度为(1, num_frames, num_patches + 1, dim)
    在这里,num_frames 是 t,num_patches 是 n=196,dim 是 768,因此 pos_embedding 维度为 (1,2,197,768)
    1.         # 位置编码
    2.         self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, num_patches + 1, dim))
    3.         # 空间维度的cls token
    4.         self.space_token = nn.Parameter(torch.randn(1, 1, dim))
    5.         # 空间变换器
    6.         self.space_transformer = Transformer(dim, depth, heads, dim_head, dim * scale_dim, dropout)
    7.         # 时间维度的cls token
    8.         self.temporal_token = nn.Parameter(torch.randn(1, 1, dim))
    9.         # 时间变换器
    10.         self.temporal_transformer = Transformer(dim, depth, heads, dim_head, dim * scale_dim, dropout)
    11.         # dropout层
    12.         self.dropout = nn.Dropout(emb_dropout)
    13.         # 池化方式
    14.         self.pool = pool
    15.         # 最后的全连接层,用于分类
    16.         self.mlp_head = nn.Sequential(
    17.             nn.LayerNorm(dim),
    18.             nn.Linear(dim, num_classes)
    19.         )
    复制代码
  • patch嵌入和cls token的拼接
    输入数据 x 的维度在颠末嵌入层后变为 (1,2,196,768)
    self.space_token 的初始维度为 (1,1,768),被复制扩展成 (1,2,1,768) 以匹配批次和时间维度
    cls_space_tokens 和 x 在patch维度上拼接后,维度变为 (1,2,197,768)
    为何拼接之后成197了呢?原因很简朴,如ViT那篇文章中所述:“[class] token的维度为 [1, 768] ,通过Concat操作,[196, 768]  与 [1, 768] 拼接得到 [197, 768]”
    1.     def forward(self, x):
    2.         # 将输入数据x转换为patch embeddings
    3.         x = self.to_patch_embedding(x)
    4.         b, t, n, _ = x.shape  # 获取batch size, 时间维度, patch数量
    5.         # 在每个空间位置加上cls token
    6.         cls_space_tokens = repeat(self.space_token, '() n d -> b t n d', b=b, t=t)
    7.         x = torch.cat((cls_space_tokens, x), dim=2)  # 在维度2上进行拼接
    复制代码
  • 添加位置编码和应用dropout
    加上位置编码后,x 保持 (1,2,197,768) 维度不变。应用dropout后,x 的维度仍然不变
    1.         x += self.pos_embedding[:, :, :(n + 1)]  # 加上位置编码
    2.         x = self.dropout(x)  # 应用dropout
    复制代码
  • 空间Transformer
    重排 x 的维度为 (2,197,768),因为 b×t=1×2=2
    空间Transformer处理后,x 的维度变为 (2,197,768)
    1.         # 将(b, t, n, d)重排为((b t), n, d),为了应用空间变换器
    2.         x = rearrange(x, 'b t n d -> (b t) n d')
    3.         x = self.space_transformer(x)  # 应用空间变换器
    4.         x = rearrange(x[:, 0], '(b t) ... -> b t ...', b=b)  # 把输出重排回(b, t, ...)
    复制代码
  • 时间Transformer
    self.temporal_token 的初始维度为(1,1,768),被复制扩展成 (1,2,768)
    cls_temporal_tokens 和 x 在时间维度上拼接后,维度变为(1,3,768)
    1.      # 在每个时间位置加上cls token
    2.         cls_temporal_tokens = repeat(self.temporal_token, '() n d -> b n d', b=b)
    3.         x = torch.cat((cls_temporal_tokens, x), dim=1)  # 在维度1上进行拼接
    4.         x = self.temporal_transformer(x)  # 应用时间变换器
    复制代码
  • 池化
    如果 self.pool 是 'mean',则对 x 在时间维度上取均值,结果维度变为 (1,768)
    如果不是 'mean',则直接取 x 的第一个时间维度的cls token,结果维度同样是 (1,768)
    1.         # 根据pool参数选择池化方式
    2.         x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
    复制代码
  • 分类头
    self.mlp_head,将 (1,768) 维度的 x 转换为终极的分类结果,其维度取决于类别数num_classes,如果 num_classes 是 10,则终极输出维度为 (1,10)
    1.         # 通过全连接层输出最终的分类结果
    2.         return self.mlp_head(x)
    复制代码
2.1.3 factorised self-attention

第二个模子如下图所示,会先盘算空间自留意力(token中有相同的时间索引,相当于同一帧画面上的token元素),再盘算时间的自留意力(token中有相同的空间索引,相当于不同帧下同一空间位置的token,比如一直在视频的左上角那一块的token块)


  • 具体进行空间留意力交互的方法是:将初始视频序列生成的
    ,通过tensor的reshape思想映射为
    ,而后盘算得到空间自留意力结果
  • 同理,在时间维度上映射得到
    ,从而进行时间自留意力的盘算

2.1.4 factorised dot-product attention

由于实行表明空间-时间自留意力或时间-空间自留意力的顺序并不重要,以是第三个模子的结构如下图所示,一半的头仅在空间轴上盘算点积留意力,另一半头则仅在时间轴上盘算,且其参数数量增长了,因为有一个额外的自留意力层

不外,该模子通过利用dot-product点积留意力操作来取代因式分解factorisation操作,通过留意力盘算的方式来取代简朴的张量reshape。思想是对于空间留意力和时间留意力分别构建对应的键、值,如下图所示(图源自萝卜社长)

2.2 DiT:将扩散过程中的U-Net 换成 Transformer(2D图像生成,带文本条件融合)

2.2.1 DiT = VAE encoder + ViT + DDPM + VAE decoder

在ViT之前,图像范畴根本是CNN的天下,包罗扩散过程中的噪声估计器所用的U-net也是卷积架构,但随着ViT的横空出世,人们天然而然开始考虑这个噪声估计器可否用Transform架构来取代
2022年年底,William Peebles(当时在UC Berkeley,Peebles在

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

花瓣小跑

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表