【论文复现】Vision Transformer(ViT)

打印 上一主题 下一主题

主题 474|帖子 474|积分 1422

1. Transformer结构


1.1 编码器息争码器

翻译这个过程需要中央体。也就是说,编码,解码之间需要一个中介,英文先编码成一个意思,再解码成中文。
那么查字典这个过程就是编码息争码的表现。首先我们的大脑会把它编码,编码这个句子的意思,然后通过字典映射解码。但是如许的过程太过于繁琐,假如让机器做,超长文本就对应着超长的数据量,也倒霉于机器学习的上下文理解。那么就有了Attention留意力机制。
1.2 Attention:留意力机制

Attention机制的核心思想是,要想翻译一个句子并不需要完全编码。像我们人类一样,仅凭借几个词就可以猜出整句话的大概意义,即使我们不懂日语,也可以根据一些汉字推出来大概的意思,这是准确度低的情况;而“中译中”这种情况,准确度固然就更高一些。
Attention留意力机制:
Attention示意图本质上是加权平均。假如我非常留意某个地方,我想要多看,那就分配更高的权重。
计算权重是利用相似度计算,
Attention机制的优点

Attention的优点是能够实现并行计算和全局视野。
并行计算的大概性是由于它不像RNN一样,依赖时序数据。它只是加权计算,但并不需要像时序数据那样,依赖像是队列一样的进出顺序。
全局视野是由于在加权计算的时间,这个计算就是涉及了团体的,它一看就能看到全部。
1.3 Self Attention

对于Self Attention来说,它的输入是一个序列,序列的获得依靠的是vector。我们把一个词转换为序列模块,需要用到vector向量去指向。而vector的指向,是有空间性的,好比说两个意思很相近大概同义的词汇,它们在空间中的距离就会比较小。相反,意思差的多,距离固然就远。
如许也可以理解为vector是和词语的意义有关系的。
留意力分配的多少取决于公式:
                                    A                         t                         t                         e                         n                         t                         i                         o                         n                         (                         Q                         ,                         K                         ,                         V                         )                         =                         s                         o                         f                         t                         m                         a                         x                         (                                              Q                                           K                                  T                                                                        d                                  k                                                       )                         V                              Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V                  Attention(Q,K,V)=softmax(dk​                    ​QKT​)V
此中,Q代表Queries(输入的信息),K代表Keys(内容信息),V代表Values(信息本身,只表达了信息的特征)。

Q,K,V的获得

本质上,是input的线性变更。计算利用的是矩阵乘法,实现方法是nn.Linear。
用点积表示相似度的方法是由于cos角投影长度可以很直观地理解两个量的相似度。
在1.3的式子里除以                                                        d                               k                                                 \sqrt{dk}                  dk            ​这一步看起来很多余,但是它是为了避免较大的数值。较大的数值会导致softmax之后值更极端,softmax的极端值会导致梯度消失。这一步相称于控制了数值范围,让它在可观测范围内。
为什么是                                                        d                               k                                                 \sqrt{dk}                  dk            ​?假设q,k是均值为0,方差为1的尺度正态分布的独立随机变量,那么它们的点积的均值和方差分别为0以及dk。
之后做逐元素相乘(enterwise)。
我们需要将单词意思转换为句中意思,这就涉及一词多义的问题,在逐元素相乘得到的sum之后的z就涉及这个问题。那么我们需要根据句子中其他词来推理当前词的意思。
就好比说Mine,它有两种意思,一种是“我的”,一种是“矿石”。这固然是大相径庭的词性和词义。假设我们完全不知道这个多义词,但我们可以通过观察它们在句子中的位置和与上下文的接洽来推理这是什么意思。
1.4 Multihead Attention

这一步的核心是复读机()
这一步就是有多个W_q,W_k,W_v,那么上述操作重复多次,将结果用concat串在一起。
如许的复读机机制就是给留意力提供多种大概性。
应用了multihead的Conditional DETR就发现不同的head会将留意力放到物体的不同边上。

1.5 输入端适配

直接把图片切分成patches,flatten操作拉平patches,然后过一个linear projection使patches维度变小,然后编号123456789…输入网络即可。
就是切蛋糕喂给encoder和decoder。

这块儿有个patch 0的原因,有一种说法是从NLP来的:为了保持团体结构,变更尽大概的少。而NLP需要一些token负责输出,需要“终止输出”的功能模块。另一种CV里的说法是整合信息,设置在1-9之外就保持了1-9本身无干扰。Patch 0本质上是dynamic pooling layer。
1.6 位置编码

图像切分重排后失去了位置信息,而transformer的内部运算与空间信息无关。如许一来,就需要把位置信息编码重新传进网络。ViT利用了一个可学习的vector来编码,维度和patch维度一样,所以编码vector和patch vector直接相加构成输入。本质是相加,而相加是concat的一种特例。
1.7 ViT结构的数据流

输入图像是256x256像素巨细,然后切开,切成N(8x8=64)个小块,每一块则是256/8=32单元长(宽)度。也就是说,现在每一小块儿是32x32。把切开的每一小块都拉平,RGB值为3,每一块儿的维度就是3x32x32=3072维。但是3072维太高了,所以过一个linear projection把维度变成1024。但是此时每个小块儿的空间位置丢失了。所以需要加上position embedding这个可学习的向量,维度一样也是1024,让他们相加。position embedding放在patch0这里,一起进入transformer。
进入了transformer encoder之后,首先由于多了一个patch 0,Patches的表示向量里数量取N+1,即(b,65,1024)。在这个norm层里patches会被归一化,一直查验维度,保证维度是一样的。
末了到MLP Head手里,就只输入负责整合信息的patch 0,此时它表示为(b,1,1024)。如许就可以做分类任务了。

1.8 训练方法

Transformer 非常吃数据量,需要大量的样本,大规模利用Pre-Train。它先在大数据集(ImageNet)上预训练,然后到小数据集上做 Fine Tuning。
迁徙过去之后,需要把原本的MLP Head换掉,换成对应类别数的FC层。处置惩罚不同尺寸输入的时间需要对Postional Encoding的结果举行插值。
插值方法:图片切好了之后,编号,但不同的input size和patch size会切出不同数量的patch,position embedding也会变。所以编号的方法需要缩放。
1.9 实行结果

Transformer的性能需要巨大数据量的保证,很吃资源。否则,无法充分的发挥出它的性能。它和ResNet的性能不相上下
Attention的距离可以等价为Conv的感受野巨细。越深的层数,Attention超过的距离越远。在最底层,也有head能覆盖到很远的距离。这阐明它确着实捕捉信息,做信息整合。
模子留意力集中的地方,都和分类的语义高度相关。
2. 代码复现

   VIT仓库链接:https://github.com/lucidrains/vit-pytorch
  Usage:
  1. import torch
  2. from vit_pytorch import ViT # 抽象出了一个VIT类
  3. v = ViT(
  4.     image_size = 256,# 图片像素大小
  5.     patch_size = 32, # patch的大小
  6.     num_classes = 1000, # 分类数量
  7.     dim = 1024,# 维度
  8.     depth = 6, # transformer的block数量
  9.     heads = 16, # 线性变换后输出张量的最后维数,多头注意力层中的头数
  10.     mlp_dim = 2048, # MLP前馈层维度
  11.     dropout = 0.1, # 每个训练步骤中被关闭神经元的比例,可以调成0
  12.     emb_dropout = 0.1 # 嵌入丢失率
  13. )
  14. img = torch.randn(1, 3, 256, 256)
  15. preds = v(img) # (1, 1000)
复制代码
2.1 切图重排

输入端适配涉及到切图和reshape。
以下是ViT部分.py代码:
  1. import torch
  2. from torch import nn
  3. from einops import rearrange, repeat
  4. from einops.layers.torch import Rearrange
  5. # helpers
  6. def pair(t):
  7.     return t if isinstance(t, tuple) else (t, t)
  8. # classes
  9. class FeedForward(nn.Module):
  10.     def __init__(self, dim, hidden_dim, dropout = 0.):
  11.         super().__init__()
  12.         self.net = nn.Sequential(
  13.             nn.LayerNorm(dim),
  14.             nn.Linear(dim, hidden_dim),
  15.             nn.GELU(),
  16.             nn.Dropout(dropout),
  17.             nn.Linear(hidden_dim, dim),
  18.             nn.Dropout(dropout)
  19.         )
  20.     def forward(self, x):
  21.         return self.net(x)
  22. class Attention(nn.Module):
  23.     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
  24.         super().__init__()
  25.         inner_dim = dim_head *  heads
  26.         project_out = not (heads == 1 and dim_head == dim)
  27.         self.heads = heads
  28.         self.scale = dim_head ** -0.5
  29.         self.norm = nn.LayerNorm(dim)
  30.         self.attend = nn.Softmax(dim = -1)
  31.         self.dropout = nn.Dropout(dropout)
  32.         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
  33.         self.to_out = nn.Sequential(
  34.             nn.Linear(inner_dim, dim),
  35.             nn.Dropout(dropout)
  36.         ) if project_out else nn.Identity()
  37.     def forward(self, x):
  38.         x = self.norm(x)
  39.         qkv = self.to_qkv(x).chunk(3, dim = -1)
  40.         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
  41.         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
  42.         attn = self.attend(dots)
  43.         attn = self.dropout(attn)
  44.         out = torch.matmul(attn, v)
  45.         out = rearrange(out, 'b h n d -> b n (h d)')
  46.         return self.to_out(out)
  47. class Transformer(nn.Module):
  48.     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
  49.         super().__init__()
  50.         self.norm = nn.LayerNorm(dim)
  51.         self.layers = nn.ModuleList([])
  52.         for _ in range(depth):
  53.             self.layers.append(nn.ModuleList([
  54.                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
  55.                 FeedForward(dim, mlp_dim, dropout = dropout)
  56.             ]))
  57.     def forward(self, x):
  58.         for attn, ff in self.layers:
  59.             x = attn(x) + x
  60.             x = ff(x) + x
  61.         return self.norm(x)
  62. class ViT(nn.Module):
  63.     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
  64.         super().__init__()
  65.         image_height, image_width = pair(image_size)
  66.         patch_height, patch_width = pair(patch_size)
  67.         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
  68.         num_patches = (image_height // patch_height) * (image_width // patch_width)
  69.         patch_dim = channels * patch_height * patch_width
  70.         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
  71.         self.to_patch_embedding = nn.Sequential(
  72.             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
  73.             nn.LayerNorm(patch_dim),
  74.             nn.Linear(patch_dim, dim),
  75.             nn.LayerNorm(dim),
  76.         )
  77.         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
  78.         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
  79.         self.dropout = nn.Dropout(emb_dropout)
  80.         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
  81.         self.pool = pool
  82.         self.to_latent = nn.Identity()
  83.         self.mlp_head = nn.Linear(dim, num_classes)
  84. # 解析代码段
  85.     def forward(self, img):
  86.         x = self.to_patch_embedding(img)
  87.         b, n, _ = x.shape
  88.         cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
  89.         x = torch.cat((cls_tokens, x), dim=1)
  90.         x += self.pos_embedding[:, :(n + 1)]
  91.         x = self.dropout(x)
  92.         x = self.transformer(x)
  93.         x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
  94.         x = self.to_latent(x)
  95.         return self.mlp_head(x)
复制代码
transpose函数的作用:重排张量(tensor)大概数组的维度。有一个形状为(batch_size, channels, height, width)的四维张量,代表一批图像数据。那就可以将将channels维度移动到最前面,即形状变为(channels, batch_size, height, width)。这时,你就可以利用transpose操作来实现这一转换。(类似于矩阵行变更)
  1. img = torch.randn(1,3,256,256)
  2. b=1
  3. c=3
  4. h=256 = h*p1,h = 8
  5. w =256
  6. self.to_patch.embedding = nn.Sequential(
  7.         Rearrange(‘b c (h p1)(w p2)-> b(h w)(p1 p2 c)’,p1 = patch_height, p2 = patch_width),# 图片切分重排
  8.         nn.Linear(patch_dim, dim) # Linear Projection of Flattened Patches
  9.         )
复制代码
2.2 构造Patch 0

这一步:
  1. cls_tokens = repeat(self.cls_token, '() n d -> b n d',b = b)
  2.                 x = torch.cat((cls_tokens, x), dim = 1) # concat方法,维度为1,在n的维度上
复制代码
2.3 positional embedding

  1. self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
  2. # 位置编码:它是一个可学习的参数,初始化为随机值。
  3. x += self.pos_embedding[:,:(n+1)]
  4. # 将位置编码加到输入序列上。
复制代码
2.4 代码示例

首先准备数据集。
  1. from __future__ import print_function
  2. import glob
  3. from itertools import chain
  4. import os
  5. import random
  6. import zipfile
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import pandas as pd
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. import torch.optim as optim
  14. from linformer import Linformer
  15. from PIL import Image
  16. from sklearn.model_selection import train_test_split
  17. from torch.optim.lr_scheduler import StepLR
  18. from torch.utils.data import DataLoader, DataSet
  19. from torchvision import datasets, transforms
  20. from tqdm.notebook import tqdm
  21. from vit_pytorch import ViT
复制代码
  1. batch_size = 64
  2. epochs = 20
  3. lr = 3e-5
  4. gamma = 0.7
  5. seed = 42
复制代码
  1. def seed_everything(seed):
  2.         random.seed(seed)
  3.         os.environ['PYTHONHASHSEED'] = str(seed)
  4.         np.random.seed(seed)
  5.         torch.manual_seed(seed)
  6.         torch.cuda.manual_seed(seed)
  7.         torch.cuda.manual_seed_all(seed)
  8.         torch.backends.cudnn.deterministic = True
  9. seed_everything(seed)
  10. device = 'cuda'
复制代码
  1. os.makedirs('data',exist_ok = True)
  2. train_dir = 'data/train'
  3. test_dir = 'data/test'
  4. with zipfile.ZipFile('data/train.zip') as train_zip:
  5.         train_zip.extractall('data')
  6. with zipfile.ZipFile('data/test.zip') as test_zip:
  7.         test_zip.extractall('data')
  8. train_list = glob.glob(os.path.join(train_dir,'*.jpg')) # 查找匹配的jpg文件
  9. test_list = glob.glob(os.path.join(test_dir,'*.jpg'))
  10. print(f'Train Data:{len(train_list)}')
  11. print(f'Test Data:{len(test_list)}')
  12. labels = [path.split('/')[-1].split('\\')[-1].split[0] for path in train_list]
  13. print(train_list[0]
  14. print(labels[0]))
复制代码
  1. random_idx = np.random.randint(1,len(train_list),size = 9)
  2. fig, axes = plt.subplots(3,3,figsize = (16,12))
  3. for idx,ax in enumerate(axes.ravel()):
  4.         img = Imag.open(train_list[idx])
  5.         ax.set_title(labels[idx])
  6.         ax.imshow(img)
复制代码
  1. train_list, valid_list = train_test_split(train_list, test_size = 0.2, stratify = labels, random_state = seed)
  2. print(f'Train Data:{len(train_list)}')
  3. print(f'Validation Data:{len(valid_list)}')
  4. print(f'Test Data:{len(test_list)}')
复制代码
  1. train_tranforms = tranforms.Compose(
  2.         [
  3.                 transforms.Resize((224,224)),
  4.                 transforms.RandomResizedCrop(224),
  5.                 transforms.RandomHorizontalFlip(),
  6.                 transforms.ToTensor(),
  7.         ]
  8. )
  9. val_tranforms = tranforms.Compose(
  10.         [
  11.                 transforms.Resize((224,224)),
  12.                 transforms.RandomResizedCrop(224),
  13.                 transforms.RandomHorizontalFlip(),
  14.                 transforms.ToTensor(),
  15.         ]
  16. )
  17. test_tranforms = tranforms.Compose(
  18.         [
  19.                 transforms.Resize((224,224)),
  20.                 transforms.RandomResizedCrop(224),
  21.                 transforms.RandomHorizontalFlip(),
  22.                 transforms.ToTensor(),
  23.         ]
  24. )
复制代码
  1. class CatsDogsDataset(Dataset):
  2.         def __init__(self, file_list, transform = None):
  3.                 self.file_list = file_list
  4.                 self.transform = transform
  5.         def __len__(self):
  6.                 self.filelength = len(self.file_list)
  7.                 return self.filelength
  8.         def __getitem__(self,idx):
  9.                 img_path = self.file_list[idx]
  10.                 img = Image.open(img_path)
  11.                 img_transformed = self.transform(img)
  12.                 label = img_path.split('/')[-1].split("\")[-1].split("、")[0]
  13.                 label = 1 if label == "dog" else 0
  14.                 return img_transformed, label
复制代码
  1. train_data = CatsDogsDataset(train_list, transform = train_transforms)
  2. valid_data = CatsDogsDataset(valid_list, transform = val_transforms)
  3. test_data = CatsDogsDataset(test_list, transform = test_transforms)
复制代码
  1. train_loader = DataLoader(dataset = train_data,batch_size = batch_size,shuffle = True)
  2. valid_loader = DataLoader(dataset = valid_data,batch_size = batch_size,shuffle = True)
  3. test_loader = DataLoader(dataset = test_data,batch_size = batch_size,shuffle = True)
复制代码
  1. print(len(train_data), len(train_loader))
  2. print(len(valid_data), len(valid_loader))
复制代码
模子建立:
  1. model = ViT(
  2.         image_size = 224,
  3.         patch_size = 16,
  4.         num_classes = 2,
  5.         dim = 768,
  6.         depth = 12,
  7.         heads = 12,
  8.         mlp_dim = 3072,
  9.         dropout = 0.1,
  10.         emb_dropout = 0.1
  11. ).to(device) # 将Transformer模型移动到指定设备上,比如GPU
  12. model.load_state_dict(torch.load('vit_base_patch16_224_r.pth'),strict = False)
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

曹旭辉

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表