CycleMLP:用于密集预测的类似 MLP 的架构

paper:https://arxiv.org/abs/2107.10224

作者单位:香港大学, 商汤科技
代码:https://github.com/ShoufaChen/CycleMLP

核心:用 Cycle-FC来替换Spatial FC(计算量大且网络对于不同图像分辨率的输入不可接受,且不能用于下游任务)

本文提出了一个简单的 MLP-like 的架构 CycleMLP,它是视觉识别和密集预测的通用主干,不同于现代 MLP 架构,例如 MLP-Mixer、ResMLP 和 gMLP,其架构与图像大小相关,因此是在目标检测和分割中不可行。

与现代方法相比,CycleMLP 有两个优势。

(1) 可以应对各种图像尺寸。

(2) 利用局部窗口实现对图像大小的线性计算复杂度。

相比之下,以前的 MLP 具有二次计算,因为它们具有完全的空间连接。

单个 CycleMLP Block 依然是分为 Token-mixing MLP 和 Channel mixing MLP,其中作者主要的贡献点在于替换 MLP-mixer 的 Token-mixing MLP 为 Cycle-FC。所以整个 CycleMLP Block 可以描述为:

何为 Cycle-FC ?要回答这个问题,我们首先来回顾一下 Channel FC 以及 Spatial FC.

Channel FC 即通道方向的映射,等效与1×1 卷积,其参数量与图像尺寸无关,而与通道数(token 维度)有关。假设输入输出特征图尺寸一致,则参数量为 C^2,其中 C 为通道数。而计算量则为 HWC^2,其中 H W 分别为特征图的高和宽。如果只考虑计算量与图像尺寸的影响的话,则为 O ( H W ) 。

Spatial FC 即 MLP-Mixer 使用的 Token-mixing 全连接层,在这里我们都是只考虑一个全连接层,则其实现的是 H W − > H W 的映射,参数量为 H^2W^2,计算量也为 H 2 W 2 C H^2W^2C,如果只考虑计算量与图像尺寸的影响的话,则为 O(H^2W^2)。并且HW 大小固定,网络对于不同图像分辨率的输入不可接受,且不能用于下游任务以及使用类似 EfficientNetV2 等的多分辨率训练策略。

为什么我们可以在复杂度分析时只考虑 H W 的影响呢?因为在金字塔结构的 MLP 中,通常一开始的 patch size 为 4,然后输入尺寸为 224×224,则一开始的 H = W = 56 = 224 / 4 ,而 C = 64 或者 96 ,所以C≪HW。如果对于下游任务而言,例如输入变为了512×512,则它们之间的差距更大了。为此在这里我们可以在复杂度分析中暂时只考虑 H W  而忽略 C 。

为了同时克服 Spatial 对于图像输入尺寸敏感以及计算量大的问题,作者提出了 Cycle-FC。其只是用通道方向的映射并且计算量和 Channel FC 保持一致。其说白了就是不断地以 [+1 0 -1 0 +1 0 -1 0 +1 …] 的方式移动特征图,将不同空间位置的特征对齐到同一个通道上,然后使用1×1 卷积。

回忆 AS-MLP,其采用的特征图移动方式则为 [+1 0 -1 +1 0 -1 +1 0 -1] 这样的成组方式,CycleMLP 则是使用“楼梯型”方式,但是其思想没有本质不同。此外,AS-MLP 确实对特征图进行了 Shift,并且采用了 zero-padding,而 CycleMLP 在具体实现过程中则是使用可变形卷积加以实现的。我个人对于 AS-MLP 与 CycleMLP 的理解如下图所示,可见他们其实核心思想是一致的。

from torchvision.ops.deform_conv import deform_conv2d

img3

CycleMLP 与 AS-MLP 只并行 H 和 W 方向的移动不同,CycleMLP 其实是三条支路并行:H 方向,W 方向,以及不移动特征图做通道方向映射。此外,AS-MLP 在一开始还做了一次 Channel Projection 进行降维。

img5

CycleMLP 最终使用的和 ViP 一样,使用 Split Attention 来融合三条支路

class CycleMLP(nn.Module):
    def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)

        self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0)
        self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0)

        self.reweight = Mlp(dim, dim // 4, dim * 3)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        c = self.mlp_c(x)

        a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)

        x = h * a[0] + w * a[1] + c * a[2]

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

最后提一句,作者将投影区间定义为是 Pseudo-Kernel,这其实也是我们常说的 感受野 一词。

img4

2.2 整体网络结构

CycleMLP 的 Patch Embedding 也很有特色,使用卷积核大小为 7×7 ,步长为 4 的卷积。后续 Hire-MLP 其实也是这样进行的 Patch Embedding。相比而言 Swin 使用卷积核大小为 4×4,步长为 4 的卷积。在近期的我自己的小实验中也发现:Patch Embedding 时具有重叠会更好,这样可以避免边界效应并在小数据集上提升性能。CycleMLP 中间采用多阶段金字塔模型,总共分为 4 个阶段,每个阶段交替重复使用 CycleMLP Block。下采样使用卷积核大小为3×3,步长为 2 的卷积,这样做也有重叠,Hire-MLP 也是这样子哈。最后经过全局池化后连接一个全连接分类器即可。作者一共提出来了四种配置:

请添加图片描述

在这四种配置,Si​ 指 Patch Embedding 中的 Patch size,Ci​ 指 Patch Embedding 的输出编码特征维度,E i ​ 为 Channel-mixing MLP 中两个全连接层中第一个全连接层的 expand radio,Li​ 则是不同 Stage 中 Block 的重复次数。

3. 下游任务实验

CycleMLP 旨在为 MLP 模型的目标检测、实例分割和语义分割提供一个有竞争力的基线。与 AS-MLP 不同之处在于,CycleMLP 在 ADE20K 上进行实验,而 AS-MLP 在 COCO 上进行的实验。这真的是巧合,还是故意避开?不敢问也不敢说。

目标检测性能表现:相比 PVT,CycleMLP 都更具有优势。

请添加图片描述

语义分割性能表现:特别是,CycleMLP 在 ADE20K val 上达到了 45.1 mIoU,与 Swin (45.2 mIOU) 相当。

请添加图片描述

4. 消融实验

作者一共进行了三组消融实验:

  • Cycle-FC VS Spacial-FC and Channel-FC: 作者将 CycleMLP 中的 Cycle-FC 替换为 Spacial-FC 或者 Channel-FC,结果发现 CycleMLP 具有更好的性能。但是只有 Channel-FC,也能达到 79.4% 的性能,真的这么高吗,比 ResNet 高那么多…
请添加图片描述
  • Cycle-FC 中三条支路的选择:Cycle-FC 中作者并行了三条支路,对他们的消融实验发现,同时拥有正交 H 和 W 方向效果很好,加上不动之后效果更好。两倍 H 方向或者两倍 W 方向比仅含有 H 或者 W 方向会好一些。
请添加图片描述
  • 测试分辨率的影响:最终发现测试正确率随分辨率先升后降,CycleMLP 表现最好。
请添加图片描述

4. 总结与反思

CycleMLP 提出了 Cycle-FC,即将不同 token 的特征对齐到同一个通道,然后使用通道映射,从而实现网络参数量计算量的降低,以及对图像分辨率不敏感。CycleMLP 也在下游任务上测试了自己的性能表现。整体而言做得还是很充分的。不过其试图造一些新的名词以强化贡献,例如 Cycle-FC 其实就是移动特征图,Pseudo-Kernel 其实就是卷积核感受野的概念。最终 CycleMLP 通过三条并行的支路构建了十字形感受野。相比 AS-MLP,CycleMLP 在感受野分析上略显不足,没有更泛化地分析以及进行消融实验。比如 CycleMLP 也可以间隔采样,例如 [+4 +2 0 -2 -4 -2 0 2 4 2 0 -2 …],就可以构建 AS-MLP 那种空洞的更大范围的感受野。(最后插一句:CycleMLP 和 AS-MLP,就像 ResMLP 与 MLP-Mixer,学术界的 Idea 真的能够这么惊人的一致吗?)

AS-MLP:首个检测与分割领域MLP架构

paper: https://arxiv.org/abs/2107.08391

github:https://github.com/svip-lab/AS-MLP

本文是上海科技大学在MLP架构方面的探索,它设计了一种轴向移位操作以便于进行空间信息交互。在架构方面,AS-MLP采用了类似PVT的分层架构,因为可以轻易的迁移到下游任务。所提方法在ImageNet数据集上取得了优于其他MLP架构的性能,在COC检测与ADE20K分割任务上取得了与Swin相当的性能。值得一提的是,AS-MLP是首个迁移到下游任务的MLP架构。注:CycleMLP与AS-MLP属于同一时期的工作,发到arxiv的时间也只差两天,说两者都是首个其实也可以。

本文提出了一种轴向移动架构AS-MLP(Axial Shifted MLP)用于不同的视觉任务(包含图像分类、检测以及分割)。不同于MLP-Mixer通过矩阵转置+词混叠MLP进行全局空域特征编码,我们在局部特征通信方向投入了更多的关注。

通过轴向移动特征信息,AS-MLP可以得到不同方向的信息流,这有助于捕获局部相关性。该操作使得我们采用纯MLP架构即可取得与CNN相同的感受野。我们还可以类似卷积核设置AS-MLP模块的感受野尺寸以及扩张因子。如此简单而有效的架构取得了优于其他MLP架构的性能,同时具有与Transformer架构(比如Swin Transformer)相当的性能,甚至具有稍少的FLOPs。比如,AS-MLP在ImageNet数据集上凭借88M参数量+15.2GFLOPs取得了83.3%top1精度,且无需额外训练数据。

此外,所提AS-MLP也是首个用于下游任务(如目标检测、语义分割)的MLP架构。AS-MLP在COC验证集上取得了51.5mAP指标,在ADE20K数据集上取得了49.5mIoU指标,具有与Transformer架构相当的性能。

Method:

Comparisons between AS-MLP, Convolution, Transformer and MLP-Mixer

在这里,我们将AS-MLP、卷积、Swin以及MLP-Mixer进行对比分析。尽管这些模型是从不同角度出发设计得到,但它们均基于给定输出位置点,其值依赖于局部特征的加权。这些采样位置包含局部依赖与长距离依赖。

从上述对比图可以看到:

  • 卷积是一种局部感受野的操作,更适合于提取具有局部依赖关系的特征;
  • Swin同样是一种局部感受野操作,Swin为自注意力机制引入了局部性提升了Transformer架构的性能,同时也降低了计算复杂度;
  • MLP-Mixer是一种全局感受野操作,它仅仅由矩阵转置与MLP操作构成;
  • AS-MLP是一种局部“十”字感受野操作,它可以更好的提取局部依赖关系。

Variants of AS-MLP Architecture

前面的Figure仅仅给出了Tiny版本的AS-MLP架构,参考DeiT与Swin,我们通过调整模块数与通道数构建了不同大小的模型。

image.png

Experiments

ImageNet Classification

image.png

上表给出了所提方法在ImageNet数据上的性能对比,从中可以看到:

  • 所提AS-MLP取得了比其他MLP架构更优的性能,同时具有相似的参数量与FLOPs;
  • AS-MLP-S取得了83.1%的top1精度同时具有比Mixer-B/16、ViP-Medium/7更少的参数量;
  • 此外,AS-MLP-B取得了与Swin相当的性能:83.3%。
image.png

此外,我们还对比了端侧配置版本的AS-MLP,结果见上表。可以看到:在端侧配置下,所提方法大幅超越了Swin Transformer。

COCO Detection

image.png

上表对比了COCO检测任务上的性能对比,可以看到:

  • 所提AS-MLP是首个用于下游任务的MLP架构;
  • 所提AS-MLP取得了与Swin相当的性能。具体来说,在Cascade Mask R-CNN+Swin-B取得了51.9AP指标,参数量为145M;而AS-MLP-B取得了51。5AP指标,参数量为145M。

ADE20K Segmentation

image.png

上表给出了ADE20K分割任务上的性能对比,从中可以看到:

  • 所提AS-MLP同样是首个用于分割任务的MLP架构;
  • AS-MLP-T取得了比Swin-T等有的性能,同时具有稍少FLOPs;
  • UperNet+Swin-B取得了49.7mIoU,参数量为121M,计算量为1188GFLOPs;而UperNet+AS-MLP-B取得了49.5mIoU,参数量121M,计算量为1166GFLOPs。

Ablation Study

AS-MLP的核心是轴向移动,接下来我们将对其不同成分进行消融分析,所有试验均基于AS-MLP-T实现。

image.png

上表对比了不同padding方式、不同移动尺寸以及不同扩展比例的性能对比,从中可以看到:

  • zero-padding更适合于AS-MLP设计;
  • 提升扩张因子会轻微降低模型性能;
  • 提升移动尺寸,模型精度会先上升后下降。
  • 基于上述分析,我们采用shift=5,zero-padding,dilation=1。

image.png
我们同时还比较了AS-MLP模块的不同链接类型,结果见上表,从中可以看到:在不同移动尺寸下,并行连接总是具有比串行连接更佳性能

Comparsion with S2MLP

在初看到该文时,第一感觉这个与百度的那篇S2MLP(见下图核心模块)真的非常相似,都是采用了垂直、水平移位方式进行空间信息交互,而且还都是上下左右四个方向。可惜AS-MLP并未与S2MLP进行对比,反而比较晚(指的是见刊arxiv)的ViP进行的对比。

image.png

既然提到了,我们还是对S2MLP与ASMLP进行一下对比吧。

  • 在整体架构方面,AS-MLP采用了类似PVT的分层架构,而S2MLP一文则是采用了类似ViT的柱状架构;
  • 在应用方面,AS-MLP即可应用于图像分类,还可以迁移到下游任务中;而S2MLP则仅适用于图像分类,并不适用下游任务;
  • 在核心模型方面,AS-MLP采用并行垂直、水平移动,分别进行特征汇聚后再进行特征相加汇聚;而S2MLP则采用分组方式,不同组进行不同方向的移动,然后再进行空间信息汇聚;
  • 在模型性能方面,AS-MLP取得了与Swin相当的性能,比ViP更优的性能;而S2MLP的性能则弱于Swin与ViP;
  • 最后一点,AS-MLP开源了,但S2MLP并未开源。

自然语言处理中注意力机制综述

注意力汇聚

查询(自主提示)和键(非自主提示)之间的交互形成了 注意力汇聚(attentionpooling)。注意力汇聚有选择地聚合了值(感官输入)以生成最终的输
出。注意力汇聚(attention pooling)公式:

其中 x 是查询,(xi; yi) 是键值对。注意力汇聚是 yi 的加权平均。将查询 x 和键 xi之间的关系建模为 注意力权重(attetnion weight) (x; xi),如 (10.2.4) 所示,这个权重将被分配给每一个对应值 yi。对于任何查询,模型在所有键值对上的注意力权重都是一个有效的概率分布:它们是非负数的,并且总和为1。

正如我们所看到的,选择不同的注意力评分函数 a 会导致不同的注意力汇聚操作。

1、加性注意力

2、缩放点积注意力

使用点积可以得到计算效率更高的评分函数。但是点积操作要求查询和键具有相同的⻓度d。假设查询和键的所有元素都是独立的随机变量,并且都满足均值为 0 和方差为 1。那么两个向量的点积的均值为 0,方差为 d。为确保无论向量⻓度如何,点积的方差在不考虑向量⻓度的情况下仍然是 1,则可以使用 缩放点积注意力(scaled dot-product attention)评分函数:

1. 写在前面

近些年来,注意力机制一直频繁的出现在目之所及的文献或者博文中,可见在nlp中算得上是个相当流行的概念,事实也证明其在nlp领域散发出不小得作用。这几年得顶会paper就能看出这一点。本文深入浅出地介绍了自然语言处理中的注意力机制技术。据Lilian Weng博主总结以及一些资料显示,Attention机制最早应该是在视觉图像领域提出来的,这方面的工作应该很多,历史也比较悠久。人类的视觉注意力虽然存在很多不同的模型,但它们都基本上归结为给予需要重点关注的目标区域(注意力焦点)更重要的注意力,同时给予周围的图像低的注意力,然后随着时间的推移调整焦点。而直到Bahdanau等人发表了论文《Neural Machine Translation by Jointly Learning to Align and Translate》,该论文使用类似attention的机制在机器翻译任务上将翻译和对齐同时进行,这个工作目前是最被认可为是第一个提出attention机制应用到NLP领域中,值得一提的是,该论文2015年被ICLR录用,截至现在,谷歌引用量为5596,可见后续nlp在这一块的研究火爆程度。

注意力机制首先从人类直觉中得到,在nlp领域的机器翻译任务上首先取得不错的应用成功。简而言之,深度学习中的注意力可以广义地解释为重要性权重的向量:为了预测一个元素,例如句子中的单词,使用注意力向量来估计它与其他元素的相关程度有多强,并将其值的总和作为目标的近似值

既然注意力机制最早在nlp领域应用于机器翻译任务,那在这个之前又是怎么做的呢?传统的基于短语的翻译系统通过将源句分成多个块然后逐个词地翻译它们来完成它们的任务。 这导致了翻译输出的不流畅。想想我们人类是如何翻译的?我们首先会阅读整个待翻译的句子,然后结合上下文理解其含义,最后产生翻译。在某种程度上,神经机器翻译(NMT)的提出正是想去模仿这一过程。而在NMT的翻译模型中经典的做法是由编码器 – 解码器架构制定(encoder-decoder),用作encoder和decoder常用的是循环神经网络。这类模型大概过程是首先将源句子的输入序列送入到编码器中,提取最后隐藏的表示并用于初始化解码器的隐藏状态,然后一个接一个地生成目标单词,这个过程广义上可以理解为不断地将前一个时刻 t-1 的输出作为后一个时刻 t 的输入,循环解码,直到输出停止符为止。通过这种方式,NMT解决了传统的基于短语的方法中的局部翻译问题:它可以捕获语言中的长距离依赖性,并提供更流畅的翻译。但是这样做也存在很多缺点,譬如,RNN是健忘的,这意味着前面的信息在经过多个时间步骤传播后会被逐渐消弱乃至消失。其次,在解码期间没有进行对齐操作,因此在解码每个元素的过程中,焦点分散在整个序列中。对于前面那个问题,LSTM、GRU在一定程度能够缓解。而后者正是Bahdanau等人重视的问题。

 2、NLP中Attention mechanism的起源

在Seq2Seq结构中,encoder把所有的输入序列都编码成一个统一的语义向量context,然后再由decoder解码。而context自然也就成了限制模型性能的瓶颈。譬如机器翻译问题,当要翻译的句子较长时,一个context可能存不下那么多信息。除此之外,只用编码器的最后一个隐藏层状态,感觉上都不是很合理。实际上当我们翻译的时候譬如:Source:机器学习–>Target:machine learning。当decoder要生成”machine”的时候,应该更关注”机器”,而生成”learning”的时候,应该给予”学习”更大的权重。所以如果要改进Seq2Seq结构,一个不错的想法自然就是利用encoder所有隐藏层状态解决context限制问题。

Bahdanau等人把attention机制用到了神经网络机器翻译(NMT)上。传统的encoder-decoder模型通过encoder将Source序列编码到一个固定维度的中间语义向量context,然后在使用decoder进行解码翻译到目标语言序列。前面谈到了这种做法的局限性,而且,Bahdanau等人在摘要也说到这个context可能是提高这种基本编码器 – 解码器架构性能的瓶颈,那Bahdanau等人又是如何尝试缓解这个问题的呢?让我们来一探究竟,作者为了缓解中间向量context很难将Source序列所有必要信息压缩进来的问题,特别是对于那些很长的句子。提出在机器翻译任务上在 encoder–decoder 做出了如下扩展:将翻译和对齐联合学习。这个操作在生成Target序列的每个词时,用到的中间语义向量context是Source序列通过encoder的隐藏层的加权和,而传统的做法是只用encoder最后一个输出 ht 作为context,这样就能保证在解码不同词的时候,Source序列对现在解码词的贡献是不一样的。想想前面那个例子:”Source:机器学习–>Target:machine learning”(假如中文按照字切分)。decoder在解码”machine”时,”机”和”器”提供的权重要更大一些,同样,在解码”learning”时,”学”和”习”提供的权重相应的会更大一些,这在直觉也和人类翻译也是一致的。通过这种attention的设计,作者将Source序列的每个词(通过encoder的隐藏层输出)和Target序列(当前要翻译的词)的每个词巧妙的建立了联系。想一想,翻译每个词的时候,都有一个语义向量,而这个语义向量是Source序列每个词通过encoder之后的隐藏层的加权和。 由此可以得到一个Source序列和Target序列的对齐矩阵,通过可视化这个矩阵,可以看出在翻译一个词的时候,Source序列的每个词对当前要翻译词的重要性分布,这在直觉上也能给人一种可解释性的感觉。

3. NLP中的注意力机制

随着注意力机制的广泛应用,在某种程度上缓解了源序列和目标序列由于距离限制而难以建模依赖关系的问题。现在已经涌现出了一大批基于基本形式的注意力的不同变体来处理更复杂的任务。让我们一起来看看其在不同NLP问题中的注意力机制。

其实我们可能已经意识到了,对齐模型的设计不是唯一的,确实,在某种意义上说,根据不同的任务设计适应于特定任务的对齐模型可以看作设计出了新的attention变体,让我们再看看这个模型(函数): score(st,hi) 。再来看几个代表性的work。

Citation等人提出Content-base attention,其对齐函数模型设计为:

Bahdanau等人的Additive(*),其设计为:

Luong[4]等人文献包含了几种方式:

以及Luong等人还尝试过location-based function。这种方法的对齐分数仅从目标隐藏状态学习得到。

  • Vaswani[6]等人的Scaled Dot-Product(^)缩放点积注意:

细心的童鞋可能早就发现了这东东和点积注意力很像,只是加了个scale factor。当输入较大时,softmax函数可能具有极小的梯度,难以有效学习,所以作者加入比例因子 1/n 。

Cheng[7]等人的Self-Attention(&)可以关联相同输入序列的不同位置。 从理论上讲,Self-Attention可以采用上面的任何 score functions。在一些文章中也称为“intra-attention” 

Hu[7]对此分了个类:

前面谈到的一些Basic Attention给人的感觉能够从序列中根据权重分布提取重要元素。而Multi-dimensional Attention能够捕获不同表示空间中的term之间的多个交互,这一点简单的实现可以通过直接将多个单维表示堆叠在一起构建。Wang[8]等人提出了coupled multi-layer attentions,该模型属于多层注意力网络模型。作者称,通过这种多层方式,该模型可以进一步利用术语之间的间接关系,以获得更精确的信息。

3.1 Hierarchical(层次) Attention

再来看看Hierarchical Attention,Yang[9]等人提出了Hierarchical Attention Networks,看下面的图可能会更直观:

Hierarchical Attention Networks

这种结构能够反映文档的层次结构。模型在单词和句子级别分别设计了两个不同级别的注意力机制,这样做能够在构建文档表示时区别地对待这些内容。Hierarchical attention可以相应地构建分层注意力,自下而上(即,词级到句子级)或自上而下(词级到字符级),以提取全局和本地的重要信息。自下而上的方法上面刚谈完。那么自上而下又是如何做的呢?让我们看看Ji[10]等人的模型:

Nested Attention Hybrid Model

和机器翻译类似,作者依旧采用encoder-decoder架构,然后用word-level attention对全局语法和流畅性纠错,设计character-level attention对本地拼写错误纠正。

3.2 Self-Attention

那Self-Attention又是指什么呢?

Self-Attention(自注意力),也称为”intra-attention”(内部注意力),是关联单个序列的不同位置的注意力机制,以便计算序列的交互表示。 它已被证明在很多领域十分有效比如机器阅读,文本摘要或图像描述生成。

  • 比如Cheng[11]等人在机器阅读里面利用了自注意力。当前单词为红色,蓝色阴影的大小表示激活程度,自注意力机制使得能够学习当前单词和句子前一部分词之间的相关性。

当前单词为红色,蓝色阴影的大小表示激活程度

  • 比如Xu[12]等人利用自注意力在图像描述生成任务。注意力权重的可视化清楚地表明了模型关注的图像的哪些区域以便输出某个单词。

我们假设序列元素为 V=vi ,其匹配向量为 u 。让我们再来回顾下前面说的基本注意力的对齐函数,attention score通过 a(u,vi) 计算得到,由于是通过将外部 u 与每个元素 vi 匹配来计算注意力,所以这种形式可以看作是外部注意力。当我们把外部u替换成序列本身(或部分本身),这种形式就可以看作为内部注意力(internal attention)。

我们根据文章[7]中的例子来看看这个过程,例如句子:”Volleyball match is in progress between ladies”。句子中其它单词都依赖着”match”,理想情况下,我们希望使用自我注意力来自动捕获这种内在依赖。换句话说,自注意力可以解释为,每个单词 vi 去和V序列中的内部模式 v′ ,匹配函数 a(v′,vi) 。 v′ 很自然的选择为V中其它单词 vj ,这样遍可以计算成对注意力得分。为了完全捕捉序列中单词之间的复杂相互作用,我们可以进一步扩展它以计算序列中每对单词之间的注意力。这种方式让每个单词和序列中其它单词交互了关系。

另一方面,自注意力还可以自适应方式学习复杂的上下文单词表示。譬如经典文章:”A structured self-attentive sentence embedding”。这篇文章提出了一种通过引入自注意力机制来提取可解释句子嵌入的新模型。 使用二维矩阵而不是向量来代表嵌入,矩阵的每一行都在句子的不同部分,想深入了解的可以去看看这篇文章,另外,文章的公式感觉真的很漂亮。

值得一提还有2017年谷歌提出的Transformer[6],这是一种新颖的基于注意力的机器翻译架构,也是一个混合神经网络,具有前馈层和自注意层。论文的题目挺霸气:Attention is All you Need,毫无疑问,它是2017年最具影响力和最有趣的论文之一。那这篇文章的Transformer的庐山真面目到底是这样的呢?

这篇文章为提出许多改进,在完全抛弃了RNN的情况下进行seq2seq建模。接下来一起来详细看看吧。


Key, Value and Query:

众所周知,在NLP任务中,通常的处理方法是先分词,然后每个词转化为对应的词向量。接着一般最常见的有二类操作,第一类是接RNN(变体LSTM、GRU、SRU等),但是这一类方法没有摆脱时序这个局限,也就是说无法并行,也导致了在大数据集上的速度效率问题。第二类是接CNN,CNN方便并行,而且容易捕捉到一些全局的结构信息。很长一段时间都是以上二种的抉择以及改造,知道谷歌提供了第三类思路:纯靠注意力,也就是现在要讲的这个东东。

将输入序列编码表示视为一组键值对(K,V)以及查询 Q,因为文章取K=V=Q,所以也自然称为Self Attention。

K, V像是key-value的关系从而是一一对应的,那么上式的意思就是通过Q中每个元素query,与K中各个元素求内积然后softmax的方式,来得到Q中元素与V中元素的相似度,然后加权求和,得到一个新的向量。其中因子 n 为了使得内积不至于太大。以上公式在文中也称为点积注意力(scaled dot-product attention):输出是值的加权和,其中分配给每个值的权重由查询的点积与所有键确定

而Transformer主要由多头自注意力(Multi-Head Self-Attention)单元组成。 在NMT的上下文中,键和值都是编码器隐藏状态。 在解码器中,先前的输出被压缩成查询Q,并且通过映射该查询以及该组键和值来产生下一个输出。

3.3 Memory-based Attention

Memory-based Attention又是什么呢?我们先换种方式来看前面的注意力,假设有一系列的键值对 (ki,vi) 存在内存中和查询向量q,这样便能重写为以下过程:

这种解释是把注意力作为使用查询q的寻址过程,这个过程基于注意力分数从memory中读取内容。聪明的童鞋肯定已经发现了,如果我们假设ki=vi ,这不就是前面谈到的基础注意力么?然而,由于结合了额外的函数,可以实现可重用性和增加灵活性,所以Memory-based attention mechanism可以设计得更加强大。

那为什么又要这样做呢?在nlp的一些任务上比如问答匹配任务,答案往往与问题间接相关,因此基本的注意力技术就显得很无力了。那处理这一任务该如何做才好呢?这个时候就体现了Memory-based attention mechanism的强大了,譬如Sukhbaatar[18]等人通过迭代内存更新(也称为多跳)来模拟时间推理过程,以逐步引导注意到答案的正确位置:

在每次迭代中,使用新内容更新查询,并且使用更新的查询来检索相关内容。一种简单的更新方法为相加 qt+1=qt+ct 。那么还有其它更新方法么?当然有,直觉敏感的童鞋肯定想到了,光是这一点,就可以根据特定任务去设计,比如Kuma[13]等人的工作。这种方式的灵活度也体现在key和value可以自由的被设计,比如我们可以自由地将先验知识结合到key和value嵌入中,以允许它们分别更好地捕获相关信息。看到这里是不是觉得文章灌水其实也不是什么难事了。

3.4 Soft/Hard Attention

这个概念由《Show, Attend and Tell: Neural Image Caption Generation with Visual Attention》提出,这是对attention另一种分类。SoftAttention本质上和Bahdanau等人[3]很相似,其权重取值在0到1之间,而Hard Attention取值为0或者1。

3.5 Global/Local Attention

Luong等人[4]提出了Global Attention和Local Attention。Global Attention本质上和Bahdanau等人[3]很相似。Global方法顾名思义就是会关注源句子序列的所有词,具体地说,在计算语义向量时,会考虑编码器所有的隐藏状态。而在Local Attention中,计算语义向量时只关注每个目标词的一部分编码器隐藏状态。由于Global方法必须计算源句子序列所有隐藏状态,当句子长度过长会使得计算代价昂贵并使得翻译变得不太实际,比如在翻译段落和文档的时候。

参考文献

[1] Attention? Attention!.

[2] Neural Machine Translation (seq2seq) Tutorial.

[3] Neural Machine Translation by Jointly Learning to Align and Translate, Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. ICLR, 2015.

[4] Effective approaches to attention-based neural machine translation, Minh-Thang Luong, Hieu Pham, and Christopher D Manning. EMNLP, 2015.

[5] Neural Turing Machines, Alex Graves, Greg Wayne and Ivo Danihelka. 2014.

[6] Attention Is All You Need, Ashish Vaswani, et al. NIPS, 2017.

[7] An Introductory Survey on Attention Mechanisms in NLP Problems Dichao Hu, 2018.

[8] Coupled Multi-Layer Attentions for Co-Extraction of Aspect and Opinion Terms Wenya Wang,Sinno Jialin Pan, Daniel Dahlmeier and Xiaokui Xiao. AAAI, 2017.

[9] Hierarchical attention networks for document classification Zichao Yang et al. ACL, 2016.

[10] A Nested Attention Neural Hybrid Model for Grammatical Error Correction Jianshu Ji et al. 2017.

[11] Long Short-Term Memory-Networks for Machine Reading Jianpeng Cheng, Li Dong and Mirella Lapata. EMNLP, 2016.

[12] Show, Attend and Tell: Neural Image Caption Generation with Visual Attention Kelvin Xu et al. JMLR, 2015.

[13] Ask me anything: Dynamic memory networks for natural language processing. Zhouhan Lin al. JMLR, 2016.

[14] A structured self-attentive sentence embedding Zhouhan Lin al. ICLR, 2017.

[15] Learning Sentence Representation with Guidance of Human Attention Shaonan Wang , Jiajun Zhang, Chengqing Zong. IJCAI, 2017.

[16] Sequence to Sequence Learning with Neural Networks Ilya Sutskever et al. 2014.

[17] Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation Kyunghyun Cho, Yoshua Bengio et al. EMNLP, 2014.

[18] End-To-End Memory Networks Sainbayar Sukhbaatar et al. NIPS, 2015.

[19] 《Attention is All You Need》浅读(简介+代码)

pytorch如何加载不同尺寸的图片数据

如何使用dataloader加载相同维度但是不同尺寸的数据集(图片),不使用resize,crop等改变模型输入的shape。

知乎:https://www.zhihu.com/question/395888465

如果加载的数据的维度尺寸不相同的话,在迭代器中会爆出如下的错误

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0.

1、pytorch的dataloader默认的collate_fn会使用torch.stack合并多张图片成为batch

要么另外写一个collate_fn

要么在dataset类中对图片做padding,使得图片的size一样,可以直接stack

2、关于collate_fn:

 https://pytorch.org/docs/stable/data.html#working-with-collate-fn

The use of collate_fn is slightly different when automatic batching is enabled or disabled.

  • When automatic batching is disabledcollate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays in PyTorch tensors.
  • When automatic batching is enabledcollate_fn is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes behavior of the default collate_fn in this case.

可以看到,你可以考虑关闭自动打包,这样collate_fn处理的就是独立的样本。也可以打开自动打包,这样这个函数就会被输入一个batch列表的数据。注意,这个列表的数据可以不同大小哦,知识这样你就没办法将其stack成一个完整的batch。所以,实际上你的报错,应该是这个位置出的问题。

所以可以考虑以下几种策略:

  • 单个样本输入,这样同一个batch组合的时候就不需要担心了
  • 对输入样本padding成最大的形状,组合成batch,之后送入网络的时候,你可以把数据拆分开,按你想要的将其去掉padding或者其他操作
  • 正常读取,之后再自定义的collate_fn中将数据拆开返回,这样可以返回相同结构的数据

对于最后一点,给个小demo:

class OurDataset(Dataset):
    def __init__(self, *tensors):
        self.tensors = tensors
    def __getitem__(self, index):
        return self.tensors[index]
    def __len__(self):
        return len(self.tensors)

def collate_wrapper(batch):
#函数就会输入一个batch的列表的数据(注意是batch是一个列表,所以里面的数据可以不同大小)
    a, b = batch
    return a, b

a = torch.randn(3, 2, 3)
b = torch.randn(3, 3, 4)
dataset = OurDataset(a, b)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper)

for sample in loader:
    print([x.size() for x in sample])

# Out: [torch.Size([1, 3, 2, 3]), torch.Size([1, 3, 3, 4])]

Swin Transformer 代码详解

code:https://github.com/microsoft/Swin-Transformer

代码详解: https://zhuanlan.zhihu.com/p/367111046

预处理:

对于分类模型,输入图像尺寸为 224×224×3 ,即 H=W=224 。按照原文描述,模型先将图像分割成每块大小为 4×4 的patch,那么就会有 56×56 个patch,这就是初始resolution,也是后面每个stage会降采样的维度。后面每个stage都会降采样时长宽降到一半,特征数加倍。按照原文及原图描述,划分的每个patch具有 4×4×3=48 维特征。

  • 实际在代码中,首先使用了PatchEmbed模块(这里的PatchEmbed包括上图中的Linear Embedding 和 patch partition层),定义如下:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): # embed_dim就是上图中的C超参数
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

可以看到,实际操作使用了一个卷积层conv2d(3, 96, 4, 4),直接就做了划分patch和编码初始特征的工作,对于输入 x:B×3×224×224 ,经过一层conv2d和LayerNorm得到 x:B×562×96 。然后作为对比,可以选择性地加上每个patch的绝对位置编码,原文实验表示这种做法不好,因此不会采用(ape=false)。最后经过一层dropout,至此,预处理完成。另外,要注意的是,代码和上面流程图并不符,其实在stage 1之前,即预处理完成后,维度已经是 H/4×W/4×C ,stage 1之后已经是 H/8×W/8×2C ,不过在stage 4后不再降采样,得到的还是 H/32×W/32×8C 。

stage处理

我们先梳理整个stage的大体过程,把简单的部分先说了,再深入到复杂得的细节。每个stage,即代码中的BasicLayer,由若干个block组成,而block的数目由depth列表中的元素决定。每个block就是W-MSA(window-multihead self attention)或者SW-MSA(shift window multihead self attention),一般有偶数个block,两种SA交替出现,比如6个block,0,2,4是W-MSA,1,3,5是SW-MSA。在经历完一个stage后,会进行下采样,定义的下采样比较有意思。比如还是 56×56 个patch,四个为一组,分别取每组中的左上,右上、左下、右下堆叠一起,经过一个layernorm,linear层,实现维度下采样、特征加倍的效果。实际上它可以看成一种加权池化的过程。代码如下:

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

在经历完4个stage后,得到的是 (H/32×W/32)×8C 的特征,将其转到 8C×(H/32×W/32) 后,接一个AdaptiveAvgPool1d(1),全局平均池化,得到 8C 特征,最后接一个分类器。

PatchMerging

Block处理

SwinTransformerBlock的结构,由LayerNorm层、windowAttention层(Window MultiHead self -attention, W-MSA)、MLP层以及shiftWindowAttention层(SW-MSA)组成。

上面说到有两种block,block的代码如下:

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # 左图中最下边的LN层layerNorm层
        self.norm1 = norm_layer(dim)
        # W_MSA层或者SW-MSA层,详细的介绍看WindowAttention部分的代码
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # 左图中间部分的LN层
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # 左图最上边的MLP层
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 这里利用shift_size控制是否执行shift window操作
        # 当shift_size为0时,不执行shift操作,对应W-MSA,也就是在每个stage中,W-MSA与SW-MSA交替出现
        # 例如第一个stage中存在两个block,那么第一个shift_size=0就是W-MSA,第二个shift_size不为0
        # 就是SW-MSA
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
#slice() 函数实现切片对象,主要用在切片操作函数里的参数传递。class slice(start, stop[, step])
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
## 上述操作是为了给每个窗口给上索引

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        # 如果需要计算 SW-MSA就需要进行循环移位。
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
#shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

W-MSA

W-MSA比较简单,只要其中shift_size设置为0就是W-MSA。下面跟着代码走一遍过程。

  • 输入: x:B×562×96 , H,W=56
  • 经过一层layerNorm
  • 变形: x:B×56×56×96
  • 直接赋值给shifted_x
  • 调用window_partition函数,输入shifted_xwindow_size=7
  • 注意窗口大小以patch为单位,比如7就是7个patch,如果56的分辨率就会有8个窗口。
  • 这个函数对shifted_x做一系列变形,最终变成 82B×7×7×96
  • 返回赋值给x_windows,再变形成 82B×72×96 ,这表示所有图片,每个图片的64个window,每个window内有49个patch。
  • 调用WindowAttention层,这里以它的num_head为3为例。输入参数为x_windowsself.attn_mask,对于W-MSA,attn_mask为None,可以不用管。

WindowAttention代码如下:

代码中使用7×7的windowsize,将feature map分割为不同的window,在每个window中计算自注意力。

Self-attention的计算公式(B为相对位置编码)

绝对位置编码是在进行self-attention计算之前为每一个token添加一个可学习的参数,相对位置编码如上式所示,是在进行self-attention计算时,在计算过程中添加一个可学习的相对位置参数。

假设window_size = 2*2即每个窗口有4个token (M=2) ,如图1所示,在计算self-attention时,每个token都要与所有的token计算QK值,如图6所示,当位置1的token计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的token为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。

坐标变换
坐标变换
相对位置索引求解流程图

最后生成的是相对位置索引,relative_position_index.shape = (M2,M2) ,在网络中注册成为一个不可学习的变量,relative_position_index的作用就是根据最终的索引值找到对应的可学习的相对位置编码。relative_position_index的数值范围(0~8),即 (2M−1)∗(2M−1) ,所以相对位置编码(relative position bias table)可以由一个3*3的矩阵表示,如图7所示:这样就根据index对应位置的索引找到table对应位置的值作为相对位置编码。

图7 相对位置编码

图7中的0-8为索引值,每个索引值都对应了 M2 维可学习数据(每个token都要计算 M2 个QK值,每个QK值都要加上对应的相对位置编码)

继续以图6中 M=2 的窗口为例,当计算位置1对应的 M2 个QK值时,应用的relative_position_index = [ 4, 5, 7, 8] (M2)个 ,对应的数据就是图7中位置索引4,5,7,8位置对应的 M2 维数据,即relative_position.shape = (M2∗M2)

相对位置编码在源码WindowAttention中应用,了解原理之后就很容易能够读懂程序:

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim # 输入通道的数量
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH  初始化表

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0]) # coords_h = tensor([0,1,2,...,self.window_size[0]-1])  维度=Wh
        coords_w = torch.arange(self.window_size[1]) # coords_w = tensor([0,1,2,...,self.window_size[1]-1])  维度=Ww

        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww


        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1

        '''
        后面我们需要将其展开成一维偏移量。而对于(2,1)和(1,2)这两个坐标,在二维上是不同的,但是通过将x\y坐标相加转换为一维偏移的时候
        他们的偏移量是相等的,所以需要对其做乘法操作,进行区分
        '''

        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # 计算得到相对位置索引
        # relative_position_index.shape = (M2, M2) 意思是一共有这么多个位置
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww 

        '''
        relative_position_index注册为一个不参与网络学习的变量
        '''
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        '''
        使用从截断正态分布中提取的值填充输入张量
        self.relative_position_bias_table 是全0张量,通过trunc_normal_ 进行数值填充
        '''
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            N: number of all patches in the window
            C: 输入通过线性层转化得到的维度C
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        '''
        x.shape = (num_windows*B, N, C)
        self.qkv(x).shape = (num_windows*B, N, 3C)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).shape = (num_windows*B, N, 3, num_heads, C//num_heads)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).shape = (3, num_windows*B, num_heads, N, C//num_heads)
        '''
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        '''
        q.shape = k.shape = v.shape = (num_windows*B, num_heads, N, C//num_heads)
        N = M2 代表patches的数量
        C//num_heads代表Q,K,V的维数
        '''
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # q乘上一个放缩系数,对应公式中的sqrt(d)
        q = q * self.scale

        # attn.shape = (num_windows*B, num_heads, N, N)  N = M2 代表patches的数量
        attn = (q @ k.transpose(-2, -1))

        '''
        self.relative_position_bias_table.shape = (2*Wh-1 * 2*Ww-1, nH)
        self.relative_position_index.shape = (Wh*Ww, Wh*Ww)
        self.relative_position_index矩阵中的所有值都是从self.relative_position_bias_table中取的
        self.relative_position_index是计算出来不可学习的量
        '''
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        '''
        attn.shape = (num_windows*B, num_heads, M2, M2)  N = M2 代表patches的数量
        .unsqueeze(0):扩张维度,在0对应的位置插入维度1
        relative_position_bias.unsqueeze(0).shape = (1, num_heads, M2, M2)
        num_windows*B 通过广播机制传播,relative_position_bias.unsqueeze(0).shape = (1, nH, M2, M2) 的维度1会broadcast到数量num_windows*B
        表示所有batch通用一个索引矩阵和相对位置矩阵
        '''
        attn = attn + relative_position_bias.unsqueeze(0)

        # mask.shape = (num_windows, M2, M2)
        # attn.shape = (num_windows*B, num_heads, M2, M2)
        if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        '''
        v.shape = (num_windows*B, num_heads, M2, C//num_heads)  N=M2 代表patches的数量, C//num_heads代表输入的维度
        attn.shape = (num_windows*B, num_heads, M2, M2)
        attn@v .shape = (num_windows*B, num_heads, M2, C//num_heads)
        '''
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)   # B_:num_windows*B  N:M2  C=num_heads*C//num_heads

        #   self.proj = nn.Linear(dim, dim)  dim = C
        #   self.proj_drop = nn.Dropout(proj_drop)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # x.shape = (num_windows*B, N, C)  N:窗口中所有patches的数量

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

在上述程序中有一段mask相关程序:

if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

这个部分对应的是Swin Transformer Block 中的SW-MSA

  • 输入 x:82B×72×96 。
  • 产生 QKV ,调用线性层后,得到 82B×72×(96×3) ,拆分给不同的head,得到 82B×72×3×3×32 ,第一个3是 QKV 的3,第二个3是3个head。再permute成 3×82B×3×72×32 ,再拆解成 q,k,v ,每个都是 82B×3×72×32 。表示所有图片的每个图片64个window,每个window对应到3个不同的head,都有一套49个patch、32维的特征。
  • q 归一化
  • qk 矩阵相乘求特征内积,得到 attn:82B×3×72×72
  • 得到相对位置的编码信息relative_position_bias
    • 代码如下:
self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
  • 这里以window_size=3为例,解释以下过程:首先生成 coords:2×3×3 ,就是在一个 3×3 的窗口内,每个位置的 y,x 坐标,而relative_coords为 2×9×9 ,就是9个点中,每个点的 y 或 x 与其他所有点的差值,比如 [0][3][1] 表示3号点(第二行第一个点)与1号点(第一行第二个点)的 y 坐标的差值。然后变形,并让两个坐标分别加上 3−1=2 ,是因为这些坐标值范围 [0,2] ,因此差值的最小值为-2,加上2后从0开始。最后让 y 坐标乘上 2×3−1=5 ,应该是一个trick,调整差值范围。最后将两个维度的差值相加,得到relative_position_index, 32×32 ,为9个点之间两两之间的相对位置编码值,最后用来到self.relative_position_bias_table中寻址,注意相对位置的最大值为 (2M−2)(2M−1) ,而这个table最多有 (2M−1)(2M−1) 行,因此保证可以寻址,得到了一组给多个head使用的相对位置编码信息,这个table是可训练的参数。
  • 回到代码中,得到的relative_position_bias为 3×72×72
  • 将其加到attn上,最后一个维度softmax,dropout
  • 与 v 矩阵相乘,并转置,合并多个头的信息,得到 82B×72×96
  • 经过一层线性层,dropout,返回
  • 返回赋值给attn_windows,变形为 82B×7×7×96
  • 调用window_reverse,打回原状: B×56×56×96
  • 返回给 x ,经过FFN:先加上原来的输入 x 作为residue结构,注意这里用到timmDropPath,并且drop的概率是整个网络结构线性增长的。然后再加上两层mlp的结果。
  • 返回结果 x 。

这样,整个过程就完成了,剩下的就是SW-MSA的一些不同的操作。

  1. 首先将windows进行半个窗口的循环移位,上图中的1, 2步骤,使用torch.roll实现。
  2. 在相同的窗口中计算自注意力,计算结果如下右图所示,window0的结构保存,但是针对window2的计算,其中3与3、6与6的计算生成了attn mask 中window2中的黄色区域,针对windows2中3与6、6与3之间不应该计算自注意力(attn mask中window2的蓝色区域),将蓝色区域mask赋值为-100,经过softmax之后,起作用可以忽略不计。同理window1与window3的计算一致。
  3. 最后再进行循环移位,恢复原来的位置。

原论文图中的Stage和程序中的一个Stage不同:

程序中的BasicLayer为一个Stage,在BasicLayer中调用了上面讲到的SwinTransformerBlock和PatchMerging模块:

class BasicLayer(nn.Module):  # 论文图中每个stage里对应的若干个SwinTransformerBlock
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth # swin_transformer blocks的个数
        self.use_checkpoint = use_checkpoint

        # build blocks  从0开始的偶数位置的SwinTransformerBlock计算的是W-MSA,奇数位置的Block计算的是SW-MSA,且shift_size = window_size//2
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)  # blk = SwinTransformerBlock
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

Part 3 : 不同视觉任务输出

程序中对应的是图片分类任务,经过Part 2 之后的数据通过 norm/avgpool/flatten:

 x = self.norm(x)  # B L C
 x = self.avgpool(x.transpose(1, 2))  # B C 1
 x = torch.flatten(x, 1) # B C

之后通过nn.Linear将特征转化为对应的类别:

self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

应用于其他不同的视觉任务时,只需要将输出进行特定的修改即可。

完整的SwinTransformer程序如下:

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes # 1000
        self.num_layers = len(depths) # [2, 2, 6, 2]  Swin_T 的配置
        self.embed_dim = embed_dim # 96
        self.ape = ape # False
        self.patch_norm = patch_norm # True
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))  # 96*2^3
        self.mlp_ratio = mlp_ratio # 4

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features) # norm_layer = nn.LayerNorm
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)  # 使用self.apply 初始化参数

    def _init_weights(self, m):
        # is_instance 判断对象是否为已知类型
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)  # x.shape = (H//4, W//4, C)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)  # self.pos_drop = nn.Dropout(p=drop_rate)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1) # B C
        return x

    def forward(self, x):
        x = self.forward_features(x)  # x是论文图中Figure 3 a图中最后的输出
        #  self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        x = self.head(x) # x.shape = (B, num_classes)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops

补充:有关swin transformer相对位置编码:

torch.roll 函数

The Question about the mask of window attention:

https://github.com/microsoft/Swin-Transformer/issues/38

torch.roll(inputshiftsdims=None) → Tensor

Roll the tensor input along the given dimension(s). Elements that are shifted beyond the last position are re-introduced at the first position. If dims is None, the tensor will be flattened before rolling and then restored to the original shape.Parameters

  • input (Tensor) – the input tensor.
  • shifts (int or tuple of python:ints) – The number of places by which the elements of the tensor are shifted. If shifts is a tuple, dims must be a tuple of the same size, and each dimension will be rolled by the corresponding value
  • dims (int or tuple of python:ints) – Axis along which to roll

沿给定维数滚动张量,移动到最后一个位置以外的元素将在第一个位置重新引入。如果没有指定尺寸,张量将在轧制前被压平,然后恢复到原始形状。

简单理解:shifts的值为正数相当于向下挤牙膏,挤出的牙膏又从顶部塞回牙膏里面;shifts的值为负数相当于向上挤牙膏,挤出的牙膏又从底部塞回牙膏里面

  • input (Tensor) —— 输入张量。
  • shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
  • dims (int 或 tuple of python:int) 确定的维度。

Example:

>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
>>> torch.roll(x, 1)
tensor([[8, 1],
        [2, 3],
        [4, 5],
        [6, 7]])

'''第0维度向下移1位,多出的[7,8]补充到顶部'''
>>> torch.roll(x, 1, 0)
tensor([[7, 8],
        [1, 2],
        [3, 4],
        [5, 6]])

'''第0维度向上移1位,多出的[1,2]补充到底部'''
>>> torch.roll(x, -1, 0)
tensor([[3, 4],
        [5, 6],
        [7, 8],
        [1, 2]])

'''tuple元祖,维度一一对应:
第0维度向下移2位,多出的[5,6][7,8]补充到顶部,
第1维向右移1位,多出的[6,8,2,4]补充到最左边'''
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],
        [8, 7],
        [2, 1],
        [4, 3]])

CPU怎么识别我们写的代码?

文章来源 图灵人工智能   转自STM32嵌入式开发,版权属于原作者,仅学术分享

先说一下半导体,啥叫半导体?就是介于导体和绝缘体中间的一种东西,比如二极管。相关文章:关于二极管的基础知识

电流可以从A端流向C端,但反过来则不行。你可以把它理解成一种防止电流逆流的东西。

当C端10V,A端0V,二极管可以视为断开。

当C端0V,A端10V,二极管可以视为导线,结果就是A端的电流源源不断的流向C端,导致最后的结果就是A端=C端=10V。

等等,不是说好的C端0V,A端10V么?咋就变成结果是A端=C端=10V了?你可以把这个理解成初始状态,当最后稳定下来之后就会变成A端=C端=10V。

文科的童鞋们对不住了,实在不懂问高中物理老师吧。反正你不能理解的话就记住这种情况下它相当于导线就行了。

利用半导体的这个特性,我们可以制作一些有趣的电路,比如【与门】。

我们把这个装置成为【与门】,把有电压的地方计为1,0电压的地方计为0。至于具体几V电压,那不重要。也就是AB必须同时输入1,输出端Y才是1;AB有一个是0,输出端Y就是0。

其他还有【或门】【非门】和【异或门】,跟这个都差不多,或门就是输入有一个是1输出就是1,输入00则输入0。

非门也好理解,就是输入1输出0,输入0输出1。

异或门难理解一些,不过也就那么回事,输入01或者10则输出1,输入00或者11则输出0。(即输入两个一样的值则输出0,输入两个不一样的值则输出1)。

这几种门都可以用二极管或者三极管做出来,具体怎么做就不演示了,有兴趣的童鞋可以自己试试。当然实际并不是用二极管三极管做的,因为它们太费电了。实际是用场效应管(也叫MOS管)做的。

然后我们就可以用门电路来做CPU了。当然做CPU还是挺难的,我们先从简单的开始:加法器。相关文章:CPU如何进行数字加法。加法器顾名思义,就是一种用来算加法的电路,最简单的就是下面这种。

AB只能输入0或者1,也就是这个加法器能算0+0,1+0或者1+1。

输出端S是结果,而C则代表是不是发生进位了,二进制1+1=10嘛。这个时候C=1,S=0。

费了大半天的力气,算个1+1是不是特别有成就感?

那再进一步算个1+2吧(二进制01+10),然后我们就发现了一个新的问题:第二位需要处理第一位有可能进位的问题,所以我们还得设计一个全加法器。

每次都这么画实在太麻烦了,我们简化一下。

也就是有3个输入2个输出,分别输入要相加的两个数和上一位的进位,然后输入结果和是否进位。然后我们把这个全加法器串起来:

我们就有了一个4位加法器,可以计算4位数的加法也就是15+15,已经达到了幼儿园中班水平,是不是特别给力?

做完加法器我们再做个乘法器吧,当然乘任意10进制数是有点麻烦的,我们先做个乘2的吧。

乘2就很简单了,对于一个2进制数数我们在后面加个0就算是乘2了。比如:

5=101(2)

10=1010(2)

以我们只要把输入都往前移动一位,再在最低位上补个零就算是乘2了。具体逻辑电路图我就不画,你们知道咋回事就行了。

那乘3呢?简单,先位移一次(乘2)再加一次。乘5呢?先位移两次(乘4)再加一次。

所以一般简单的CPU是没有乘法的,而乘法则是通过位移和加算的组合来通过软件来实现的。这说的有点远了,我们还是继续做CPU吧。

现在假设你有8位加法器了,也有一个位移1位的模块了。串起来你就能算(A+B)×2了!激动人心,已经差不多到了准小学生水平。

那我要是想算A×2+B呢?简单,你把加法器模块和位移模块的接线改一下就行了,改成输入A先过位移模块,再进加法器就可以了。

你的意思是我改个程序还得重新接线?

所以你以为呢?

实际上,编程就是把线来回插啊。惊喜不惊喜?意外不意外?

早期的计算机就是这样编程的,几分钟就算完了但插线好几天。关于插线编程的相关文章推荐看着篇:国内大神手工焊接,制作了一个CPU。而且插线是个细致且需要耐心的工作,所以那个时候的程序员都是清一色的漂亮女孩子,穿制服的那种,就像照片上这样。是不是有种生不逢时的感觉?

插线也是个累死人的工作。所以我们需要改进一下,让CPU可以根据指令来相加或者乘2。这里再引入两个模块,一个叫flip-flop,简称FF,中文好像叫触发器,如下图这样。

这个模块的作用是存储1bit数据。比如上面这个RS型的FF,R是Reset,输入1则清零。S是Set,输入1则保存1。RS都输入0的时候,会一直输出刚才保存的内容。

我们用FF来保存计算的中间数据(也可以是中间状态或者别的什么),1bit肯定是不够的,不过我们可以并联嘛,用4个或者8个来保存4位或者8位数据。这种我们称之为寄存器(Register)。另外一个叫MUX,中文叫选择器,如下图就是一个选择器。

这个就简单了,sel输入0则输出i0的数据,i0是什么就输出什么,01皆可。同理sel如果输入1则输出i1的数据。当然选择器可以做的很长,比如这种四进一出的具体原理不细说了,其实看看逻辑图琢磨一下就懂了,知道有这个东西就行了。下图是一个四进一出-选择器。

有这个东西我们就可以给加法器和乘2模块(位移)设计一个激活针脚。

这个激活针脚输入1则激活这个模块,输入0则不激活。这样我们就可以控制数据是流入加法器还是位移模块了。

于是我们给CPU先设计8个输入针脚,4位指令,4位数据。

我们再设计3个指令:

  • 0100,数据读入寄存器
  • 0001,数据与寄存器相加,结果保存到寄存器
  • 0010,寄存器数据向左位移一位(乘2)

为什么这么设计呢,刚才也说了,我们可以为每个模块设计一个激活针脚。然后我们可以分别用指令输入的第二第三第四个针脚连接寄存器,加法器和位移器的激活针脚。

这样我们输入0100这个指令的时候,寄存器输入被激活,其他模块都是0没有激活,数据就存入寄存器了。同理,如果我们输入0001这个指令,则加法器开始工作,我们就可以执行相加这个操作了。

这里就可以简单回答这个问题的第一个小问题了:CPU是为什么能看懂这些二级制的数呢?

为什么CPU能看懂,因为CPU里面的线就是这么接的呗。你输入一个二进制数,就像开关一样激活CPU里面若干个指定的模块以及改变这些模块的连同方式,最终得出结果。

几个可能会被问的问题

Q:CPU里面可能有成千上万个小模块,一个32位/64位的指令能控制那么多吗?

A:我们举例子的CPU里面只有3个模块,就直接接了。真正的CPU里会有一个解码器(decoder),把指令翻译成需要的形式。

Q:你举例子的简单CPU,如果我输入指令0011会怎么样?

A:当然是同时激活了加法器和位移器从而产生不可预料的后果,简单的说因为你使用了没有设计的指令,所以后果自负呗。在真正的CPU上这么干大概率就是崩溃呗,不过肯定会有各种保护性的设计。

细心的小伙伴可能发现一个问题:你设计的指令【0001,数据与寄存器相加,结果保存到寄存器】这个一步做不出来吧?

毕竟还有一个回写的过程,实际上确实是这样。我们设计的简易CPU执行一个指令差不多得三步,读取指令,执行指令,写寄存器。

经典的RISC设计则是分5步:读取指令(IF),解码指令(ID),执行指令(EX),内存操作(MEM),写寄存器(WB)。我们平常用的x86的CPU有的指令可能要分将近20个步骤。

你可以理解有这么一个开关,我们啪的按一下,CPU就走一步,你按的越快CPU就走的越快。咦?听说你有个想法?少年,你这个想法很危险啊,姑且不说你能不能按那么快。拿现代的CPU来说,也就2GHz多吧,大概一秒也就按个20亿下吧。

就算你能按那么快,虽然速度是上去了,但功耗会大大增加,发热上升稳定性下降。江湖上确实有这种玩法,名曰超频,不过新手不推荐你尝试哈。

那CPU怎么知道自己走到哪一步了呢?前面不是介绍了FF么,这个不光可以用来存中间数据,也可以用来存中间状态,也就是走到哪了。

具体的设计涉及到FSM(finite-state machine),也就是有限状态机理论,以及怎么用FF实装。这个也是很重要的一块,考试必考哈,只不过跟题目关系不大,这里就不展开讲了。

我们再继续刚才的讲,现在我们有3个指令了。我们来试试算个(1+4)X2+3吧。

0100 0001 ;寄存器存入1

0001 0100 ;寄存器的数字加4

0010 0000 ;乘2

0001 0011 ;再加三

太棒了,靠这台计算机我们应该可以打败所有的幼儿园小朋友,称霸大班了。而且现在我们用的是4位的,如果换成8位的CPU完全可以吊打低年级小学生了!

实际上用程序控制CPU是个挺高级的想法,再此之前计算机(器)的CPU都是单独设计的。

1969年一家日本公司BUSICOM想搞程控的计算器,而负责设计CPU的美国公司也觉得每次都重新设计CPU是个挺傻X的事,于是双方一拍即合,于1970年推出一种划时代的产品,世界上第一款微处理器4004。

这个架构改变了世界,那家负责设计CPU的美国公司也一步一步成为了业界巨头。哦对了,它叫Intel,对,就是噔噔噔噔的那个。

我们把刚才的程序整理一下:

“01000001000101000010000000010011”

你来把它输入CPU,我去准备一下去幼儿园大班踢馆的工作。

什么!?等我们输完了人家小朋友掰手指都能算出来了?

没办法机器语言就是这么反人类。哦,忘记说了,这种只有01组成的语言被称之为机器语言(机器码),是CPU唯一可以理解的语言。不过你把机器语言让人读,绝对一秒变典韦,这谁也受不了。

所以我们还是改进一下吧。不过话虽这么讲,也就往前个30年,直接输入01也是个挺普遍的事情。

于是我们把我们机器语言写成的程序:

0100 0001 ;寄存器存入1

0001 0100 ;寄存器的数字加4

0010 0000 ;乘2

0001 0011 ;再加三

改写成:

MOV 1 ;寄存器存入1

ADD 4 ;寄存器的数字加4

SHL 0 ;乘2(介于我们设计的乘法器暂时只能乘2,这个0是占位的)

ADD 3 ;再加三

是不是容易读多了?这就叫汇编语言。

汇编语言的好处在于它和机器语言一一对应。

也就是我们写的汇编可以完美的改写成机器语言,直接指挥cpu,进行底层开发;我们也可以把内存中的数据dump出来,以汇编语言的形式展示出来,方便调试和debug。

汇编语言极大的增强了机器语言的可读性和开发效率,但对于人类来说也依然是太晦涩了,于是我们又发明了高级语言,以近似于人类的语法来表现数据结构和算法。

比如很多语言都可以这么写:

a=(1+4)*2+3;

当然这样计算机是不认识的,我们要把它翻译成计算机认识的形式,这个过程叫编译,用来做这个事的东西叫编译器。

具体怎么把高级语言弄成汇编语言/机器语言的,一本书都写不完,我们就举个简单的例子。

我们把:

(1+4)*2+3

转换成:

1,4,+,2,*,3,+

这种写法叫后缀表示法,也成为逆波兰表示法。相对的,我们平常用的表示法叫中缀表示法,也就是符号方中间,比如1+4。而后缀表示法则写成1,4,+。

转换成这种写法的好处是没有先乘除后加减的影响,也没有括号了,直接算就行了。

具体怎么转换的可以找本讲编译原理的书看看,这里不展开讲了。

转换成这种形式之后我们就可以把它改成成汇编语言了。

从头开始处理,最开始是1,一个数字,那就存入寄存器:

MOV 1

之后是4,+,那就加一下:

ADD 4

然后是2,*,那就乘一下(介于我们设计的乘法器暂时只能乘2,这个0是占位的):

SHL 0

最后是3,+,那再加一下:

ADD 3

最后我们把翻译好的汇编整理一下:

MOV 1

ADD 4

SHL 0

ADD 3

再简单的转换成机器语言,就可以拿到我们设计的简单CPU上运行了。

其实到了这一步,应该把这个问题都讲清楚了:C语言写出来的东西是怎么翻译成二进制的,电脑又是怎么运行这个二进制的。

只不过题主最后还提到栈和硬件的关系,这里就再多说几句。

其实栈是一种数据结构,跟CPU无关。只不过栈这个数据结构实在太常用了,以至于CPU会针对性的进行优化。为了能让我们的CPU也能用栈,我们给它增加几个组件。

第一,增加一组寄存器。现在有两组寄存器了,我们分别成为A和B。

第二,增加两个指令,RDA/RDB和WRA/WRB,分别为把指定内存地址的数据读到寄存器A/B,和把寄存器A/B的内容写到指定地址。

顺便再说下内存,内存有个地址总线,有个数据总线。比如你要把1100这个数字存到0011这个地址,就把1100接到数据总线,0011接到地址总线,都准备好了啪嚓一按开关(对,就是我们前面提到的那个开关),就算是存进去了。

什么叫DDR内存呢,就是你按这个开关的时候存进去一个数字,抬起来之前你把地址和数据都更新一下,然后一松手,啪!又进去一个。也就是正常的内存你按一下进去1个数据,现在你按一下进去俩数据,这就叫双倍速率(Double Data Rate,简称DDR)

加了这几个命令之后我们发现按原来的设计,CPU每个指令针脚控制一个模块的方式的话针脚不够用了。所以我们就需要加一个解码器了(decoder)。

于是我们选择用第二个位作为是否选择寄存器的针脚。如果为0,则第三第四位可以正常激活位移器和加法器;如果为1则只激活寄存器而不激活位移和加法器,然后用第四位来决定是寄存器A还是B。这样变成了:

  • 0100,数据读入寄存器A
  • 0101,数据读入寄存器B (我们把汇编指令定义为MOVB)
  • 0001,数据与寄存器A相加,结果保存到寄存器A
  • 0011,数据与寄存器B相加,结果保存到寄存器B(我们把汇编指令定义为ADDB)
  • 0010,寄存器A数据向左位移一位(乘2)

最后我们可以用第一位来控制是不是进行内存操作。如果第一位为1则也不激活位移和加法器模块,然后用第三个针脚来控制是读还是写。这样就有了:

  • 1100,把寄存器B的地址数据读入寄存器A(我们把汇编指令定义为RD)
  • 1110,寄存器A的数据写到寄存器B指定的地址(我们把汇编指令定义为WR)

我们加了个解码器之后,加法器的激活条件从p4变成了(NOT (p1 OR p2)) AND p4。

加法器的输入则由第三个针脚判断,0则为寄存器A,1为寄存器B。这就是简单的指令解码啦。

当然我们也可以选择不向下兼容,另外设计一套指令。不过放到现实世界恐怕就要出大乱子了,所以你也可以想象我们平常用的x86背了个多大的历史包袱。

这个时候我们用栈的话,先栈地址初始化:

0101 1000 ; MOVB 16; 把栈底地址定义为1000

之后入栈的话,比如把数字3,4入栈:

1111 0011 ; WR 03; 把3写到内存,地址为1000

0011 0001 ; ADDB 01; 栈地址+1

1111 0100 ; WR 04; 把3写到内存,地址为1001

0011 0001 ; ADDB 01; 栈地址+1

这样就把3,4都保存到栈里了。

出栈的话反过来:

0011 1111 ; ADDB -1; 栈地址-1

1101 0000 ; RD 00; 把内容读入寄存器A,00是占位

0011 1111 ; ADDB -1; 栈地址-1

1101 0000 ; RD 00; 把内容读入寄存器A,00是占位

这样就依次得到4,3两个值。

所以,入栈出栈其实就是把数据写道指定的内存位置,CPU其实不知道你是在干啥。相关文章:关于C语言堆栈的经典讲解。当然我们也可以让CPU知道。

接下来我们再改进一下,给CPU再加一个寄存器SP,并定义两个指令:一个PUSH,一个POP。动作分别是把数据写入SP的地址,然后SP=SP+1,POP的话反过来。

这样有什么好处呢?好处在于PUSH/POP这样的指令消耗特别少,速度特别快。而栈这种数据结构在各种程序里用的又特别频繁,设计成专用的指令则可以很大程度上提升效率。

当然前提是编译器知道这个指令,并且做了优化,所以同样的程序(c语言写的),编译参数不一样(打开/关闭某些特性),编译出来的东西也就不一样,在不同硬件上的运行的效率也就会不一样。

比如上古时代的mmx,今天的SSE4.2,AVX-512,给力不给力?特别给力,但你平常用的程序支不支持是另一码事,要支持怎么办?重新编译呗。

这个时候开源的优势就显示出来了,重新编译很方便。闭源的话你就要指望作者开恩啦。

对于大多数人来说,电脑就是个黑箱,我们很难理解它到底是怎用工作的。这个问题又很难一句两句解释清楚,因为它是一环扣一环的,每一环都很抽象,每一环都是基础值俩个学分,展开了讲没上限的那种。

这就导致了即使是系统学过计算机的人也不见得就有一个明确而清晰的思路。想用尽量短的篇幅和尽量简单的语言把这个事从头到位解释了一下,希望能给大家解答一些疑惑。关于软硬件结合,另外也推荐下这篇文章:代码是如何控制硬件的?

空洞卷积

Multi-Scale Context Aggregation by Dilated Convolutions

一个简单的例子,[动态图来源:vdumoulin/conv_arithmetic]:

动图
Standard Convolution with a 3 x 3 kernel (and padding)
动图封面
Dilated Convolution with a 3 x 3 kernel and dilation rate 2

对于 dilated convolution, 我们已经可以发现他的优点,即内部数据结构的保留和避免使用 down-sampling 这样的特性。但是完全基于 dilated convolution 的结构如何设计则是一个新的问题。

潜在问题 1:The Gridding Effect

假设我们仅仅多次叠加 dilation rate 2 的 3 x 3 kernel 的话,则会出现这个问题:

我们发现我们的 kernel 并不连续,也就是并不是所有的 pixel 都用来计算了,因此这里将信息看做 checker-board 的方式会损失信息的连续性。这对 pixel-level dense prediction 的任务来说是致命的。

潜在问题 2:Long-ranged information might be not relevant.

我们从 dilated convolution 的设计背景来看就能推测出这样的设计是用来获取 long-ranged information。然而光采用大 dilation rate 的信息或许只对一些大物体分割有效果,而对小物体来说可能则有弊无利了。如何同时处理不同大小的物体的关系,则是设计好 dilated convolution 网络的关键。

通向标准化设计:Hybrid Dilated Convolution (HDC)

对于上个 section 里提到的几个问题,图森组的文章对其提出了较好的解决的方法。他们设计了一个称之为 HDC 的设计结构。

第一个特性是,叠加卷积的 dilation rate 不能有大于1的公约数。比如 [2, 4, 6] 则不是一个好的三层卷积,依然会出现 gridding effect。

第二个特性是,我们将 dilation rate 设计成 锯齿状结构,例如 [1, 2, 5, 1, 2, 5] 循环结构。

第三个特性是,我们需要满足一下这个式子: Mi=max[Mi+1−2ri,Mi+1−2(Mi+1−ri),ri]

其中 ri 是 i 层的 dilation rate 而 Mi 是指在 i 层的最大dilation rate,那么假设总共有n层的话,默认 Mn=rn 。假设我们应用于 kernel 为 k x k 的话,我们的目标则是 M2≤k ,这样我们至少可以用 dilation rate 1 即 standard convolution 的方式来覆盖掉所有洞。

一个简单的例子: dilation rate [1, 2, 5] with 3 x 3 kernel (可行的方案)

而这样的锯齿状本身的性质就比较好的来同时满足小物体大物体的分割要求(小 dilation rate 来关心近距离信息,大 dilation rate 来关心远距离信息)。

这样我们的卷积依然是连续的也就依然能满足VGG组观察的结论,大卷积是由小卷积的 regularisation 的 叠加。

代码:(绘制空洞卷积)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def dilated_conv_one_pixel(center: (int, int),feature_map: np.ndarray,k: int = 3,r: int = 1,v: int = 1):
    """
    膨胀卷积核中心在指定坐标center处时,统计哪些像素被利用到,
    并在利用到的像素位置处加上增量v
    Args:
        center: 膨胀卷积核中心的坐标
        feature_map: 记录每个像素使用次数的特征图
        k: 膨胀卷积核的kernel大小
        r: 膨胀卷积的dilation rate
        v: 使用次数增量
    """
    assert divmod(3, 2)[1] == 1

    # left-top: (x, y)
    left_top = (center[0] - ((k - 1) // 2) * r, center[1] - ((k - 1) // 2) * r)
    for i in range(k):
        for j in range(k):
            feature_map[left_top[1] + i * r][left_top[0] + j * r] += v


def dilated_conv_all_map(dilated_map: np.ndarray,
                         k: int = 3,
                         r: int = 1):
    """
    根据输出特征矩阵中哪些像素被使用以及使用次数,
    配合膨胀卷积k和r计算输入特征矩阵哪些像素被使用以及使用次数
    Args:
        dilated_map: 记录输出特征矩阵中每个像素被使用次数的特征图
        k: 膨胀卷积核的kernel大小
        r: 膨胀卷积的dilation rate
    """
    new_map = np.zeros_like(dilated_map)
    for i in range(dilated_map.shape[0]):
        for j in range(dilated_map.shape[1]):
            if dilated_map[i][j] > 0:
                dilated_conv_one_pixel((j, i), new_map, k=k, r=r, v=dilated_map[i][j])

    return new_map


def plot_map(matrix: np.ndarray):
    plt.figure()

    c_list = ['white', 'blue', 'red']
    new_cmp = LinearSegmentedColormap.from_list('chaos', c_list)
    plt.imshow(matrix, cmap=new_cmp)

    ax = plt.gca()
    ax.set_xticks(np.arange(-0.5, matrix.shape[1], 1), minor=True)
    ax.set_yticks(np.arange(-0.5, matrix.shape[0], 1), minor=True)

    # 显示color bar
    plt.colorbar()

    # 在图中标注数量
    thresh = 5
    for x in range(matrix.shape[1]):
        for y in range(matrix.shape[0]):
            # 注意这里的matrix[y, x]不是matrix[x, y]
            info = int(matrix[y, x])
            ax.text(x, y, info,
                    verticalalignment='center',
                    horizontalalignment='center',
                    color="white" if info > thresh else "black")
    ax.grid(which='minor', color='black', linestyle='-', linewidth=1.5)
    plt.show()
    plt.close()


def main():
    # bottom to top
    dilated_rates = [1, 2, 3]
    # init feature map
    size = 31
    m = np.zeros(shape=(size, size), dtype=np.int32)
    center = size // 2
    m[center][center] = 1
    # print(m)
    # plot_map(m)

    for index, dilated_r in enumerate(dilated_rates[::-1]):
        new_map = dilated_conv_all_map(m, r=dilated_r)
        m = new_map
    print(m)
    plot_map(m)

绘制结果:

VIT

Dosovitskiy et al. An image is worth 16×16 words: transformers for image recognition at scale. In ICLR, 2021

step1 :分割图片

step2 向量化:从九个快变成九个向量

step3:向量线性变换:(linear embedding线性嵌入层)

step4:将位置编码添加到z上:

step4:添加一个cls向量:

step5:只利用cls的输出

按照上面的流程图,一个ViT block可以分为以下几个步骤

(1) patch embedding:例如输入图片大小为224×224,将图片分为固定大小的patch,patch大小为16×16,则每张图像会生成224×224/16×16=196个patch,即输入序列长度为196,每个patch维度16x16x3=768,线性投射层的维度为768xN (N=768),因此输入通过线性投射层之后的维度依然为196×768,即一共有196个token,每个token的维度是768。这里还需要加上一个特殊字符cls,因此最终的维度是197×768。到目前为止,已经通过patch embedding将一个视觉问题转化为了一个seq2seq问题

(2) positional encoding(standard learnable 1D position embeddings):ViT同样需要加入位置编码,位置编码可以理解为一张表,表一共有N行,N的大小和输入序列长度相同,每一行代表一个向量,向量的维度和输入序列embedding的维度相同(768)。注意位置编码的操作是sum,而不是concat。加入位置编码信息之后,维度依然是197×768

(3) LN/multi-head attention/LN:LN输出维度依然是197×768。多头自注意力时,先将输入映射到q,k,v,如果只有一个头,qkv的维度都是197×768,如果有12个头(768/12=64),则qkv的维度是197×64,一共有12组qkv,最后再将12组qkv的输出拼接起来,输出维度是197×768,然后在过一层LN,维度依然是197×768

(4) MLP:将维度放大再缩小回去,197×768放大为197×3072,再缩小变为197×768

一个block之后维度依然和输入相同,都是197×768,因此可以堆叠多个block。最后会将特殊字符cls对应的输出 zL0 作为encoder的最终输出 ,代表最终的image presentation(另一种做法是不加cls字符,对所有的tokens的输出做一个平均),如下图公式(4),后面接一个MLP进行图片分类

vit需要预训练+微调

• Pretrain the model on Dataset A, fine-tune the model on Dataset B,
and evaluate the model on Dataset B.
• Pretrained on ImageNet (small), ViT is slightly worse than ResNet.
• Pretrained on ImageNet-21K (medium), ViT is comparable to ResNet.
• Pretrained on JFT (large), ViT is slightly better than ResNet.

效果: