AS-MLP:首个检测与分割领域MLP架构

paper: https://arxiv.org/abs/2107.08391

github:https://github.com/svip-lab/AS-MLP

本文是上海科技大学在MLP架构方面的探索,它设计了一种轴向移位操作以便于进行空间信息交互。在架构方面,AS-MLP采用了类似PVT的分层架构,因为可以轻易的迁移到下游任务。所提方法在ImageNet数据集上取得了优于其他MLP架构的性能,在COC检测与ADE20K分割任务上取得了与Swin相当的性能。值得一提的是,AS-MLP是首个迁移到下游任务的MLP架构。注:CycleMLP与AS-MLP属于同一时期的工作,发到arxiv的时间也只差两天,说两者都是首个其实也可以。

本文提出了一种轴向移动架构AS-MLP(Axial Shifted MLP)用于不同的视觉任务(包含图像分类、检测以及分割)。不同于MLP-Mixer通过矩阵转置+词混叠MLP进行全局空域特征编码,我们在局部特征通信方向投入了更多的关注。

通过轴向移动特征信息,AS-MLP可以得到不同方向的信息流,这有助于捕获局部相关性。该操作使得我们采用纯MLP架构即可取得与CNN相同的感受野。我们还可以类似卷积核设置AS-MLP模块的感受野尺寸以及扩张因子。如此简单而有效的架构取得了优于其他MLP架构的性能,同时具有与Transformer架构(比如Swin Transformer)相当的性能,甚至具有稍少的FLOPs。比如,AS-MLP在ImageNet数据集上凭借88M参数量+15.2GFLOPs取得了83.3%top1精度,且无需额外训练数据。

此外,所提AS-MLP也是首个用于下游任务(如目标检测、语义分割)的MLP架构。AS-MLP在COC验证集上取得了51.5mAP指标,在ADE20K数据集上取得了49.5mIoU指标,具有与Transformer架构相当的性能。

Method:

Comparisons between AS-MLP, Convolution, Transformer and MLP-Mixer

在这里,我们将AS-MLP、卷积、Swin以及MLP-Mixer进行对比分析。尽管这些模型是从不同角度出发设计得到,但它们均基于给定输出位置点,其值依赖于局部特征的加权。这些采样位置包含局部依赖与长距离依赖。

从上述对比图可以看到:

  • 卷积是一种局部感受野的操作,更适合于提取具有局部依赖关系的特征;
  • Swin同样是一种局部感受野操作,Swin为自注意力机制引入了局部性提升了Transformer架构的性能,同时也降低了计算复杂度;
  • MLP-Mixer是一种全局感受野操作,它仅仅由矩阵转置与MLP操作构成;
  • AS-MLP是一种局部“十”字感受野操作,它可以更好的提取局部依赖关系。

Variants of AS-MLP Architecture

前面的Figure仅仅给出了Tiny版本的AS-MLP架构,参考DeiT与Swin,我们通过调整模块数与通道数构建了不同大小的模型。

image.png

Experiments

ImageNet Classification

image.png

上表给出了所提方法在ImageNet数据上的性能对比,从中可以看到:

  • 所提AS-MLP取得了比其他MLP架构更优的性能,同时具有相似的参数量与FLOPs;
  • AS-MLP-S取得了83.1%的top1精度同时具有比Mixer-B/16、ViP-Medium/7更少的参数量;
  • 此外,AS-MLP-B取得了与Swin相当的性能:83.3%。
image.png

此外,我们还对比了端侧配置版本的AS-MLP,结果见上表。可以看到:在端侧配置下,所提方法大幅超越了Swin Transformer。

COCO Detection

image.png

上表对比了COCO检测任务上的性能对比,可以看到:

  • 所提AS-MLP是首个用于下游任务的MLP架构;
  • 所提AS-MLP取得了与Swin相当的性能。具体来说,在Cascade Mask R-CNN+Swin-B取得了51.9AP指标,参数量为145M;而AS-MLP-B取得了51。5AP指标,参数量为145M。

ADE20K Segmentation

image.png

上表给出了ADE20K分割任务上的性能对比,从中可以看到:

  • 所提AS-MLP同样是首个用于分割任务的MLP架构;
  • AS-MLP-T取得了比Swin-T等有的性能,同时具有稍少FLOPs;
  • UperNet+Swin-B取得了49.7mIoU,参数量为121M,计算量为1188GFLOPs;而UperNet+AS-MLP-B取得了49.5mIoU,参数量121M,计算量为1166GFLOPs。

Ablation Study

AS-MLP的核心是轴向移动,接下来我们将对其不同成分进行消融分析,所有试验均基于AS-MLP-T实现。

image.png

上表对比了不同padding方式、不同移动尺寸以及不同扩展比例的性能对比,从中可以看到:

  • zero-padding更适合于AS-MLP设计;
  • 提升扩张因子会轻微降低模型性能;
  • 提升移动尺寸,模型精度会先上升后下降。
  • 基于上述分析,我们采用shift=5,zero-padding,dilation=1。

image.png
我们同时还比较了AS-MLP模块的不同链接类型,结果见上表,从中可以看到:在不同移动尺寸下,并行连接总是具有比串行连接更佳性能

Comparsion with S2MLP

在初看到该文时,第一感觉这个与百度的那篇S2MLP(见下图核心模块)真的非常相似,都是采用了垂直、水平移位方式进行空间信息交互,而且还都是上下左右四个方向。可惜AS-MLP并未与S2MLP进行对比,反而比较晚(指的是见刊arxiv)的ViP进行的对比。

image.png

既然提到了,我们还是对S2MLP与ASMLP进行一下对比吧。

  • 在整体架构方面,AS-MLP采用了类似PVT的分层架构,而S2MLP一文则是采用了类似ViT的柱状架构;
  • 在应用方面,AS-MLP即可应用于图像分类,还可以迁移到下游任务中;而S2MLP则仅适用于图像分类,并不适用下游任务;
  • 在核心模型方面,AS-MLP采用并行垂直、水平移动,分别进行特征汇聚后再进行特征相加汇聚;而S2MLP则采用分组方式,不同组进行不同方向的移动,然后再进行空间信息汇聚;
  • 在模型性能方面,AS-MLP取得了与Swin相当的性能,比ViP更优的性能;而S2MLP的性能则弱于Swin与ViP;
  • 最后一点,AS-MLP开源了,但S2MLP并未开源。

自然语言处理中注意力机制综述

注意力汇聚

查询(自主提示)和键(非自主提示)之间的交互形成了 注意力汇聚(attentionpooling)。注意力汇聚有选择地聚合了值(感官输入)以生成最终的输
出。注意力汇聚(attention pooling)公式:

其中 x 是查询,(xi; yi) 是键值对。注意力汇聚是 yi 的加权平均。将查询 x 和键 xi之间的关系建模为 注意力权重(attetnion weight) (x; xi),如 (10.2.4) 所示,这个权重将被分配给每一个对应值 yi。对于任何查询,模型在所有键值对上的注意力权重都是一个有效的概率分布:它们是非负数的,并且总和为1。

正如我们所看到的,选择不同的注意力评分函数 a 会导致不同的注意力汇聚操作。

1、加性注意力

2、缩放点积注意力

使用点积可以得到计算效率更高的评分函数。但是点积操作要求查询和键具有相同的⻓度d。假设查询和键的所有元素都是独立的随机变量,并且都满足均值为 0 和方差为 1。那么两个向量的点积的均值为 0,方差为 d。为确保无论向量⻓度如何,点积的方差在不考虑向量⻓度的情况下仍然是 1,则可以使用 缩放点积注意力(scaled dot-product attention)评分函数:

1. 写在前面

近些年来,注意力机制一直频繁的出现在目之所及的文献或者博文中,可见在nlp中算得上是个相当流行的概念,事实也证明其在nlp领域散发出不小得作用。这几年得顶会paper就能看出这一点。本文深入浅出地介绍了自然语言处理中的注意力机制技术。据Lilian Weng博主总结以及一些资料显示,Attention机制最早应该是在视觉图像领域提出来的,这方面的工作应该很多,历史也比较悠久。人类的视觉注意力虽然存在很多不同的模型,但它们都基本上归结为给予需要重点关注的目标区域(注意力焦点)更重要的注意力,同时给予周围的图像低的注意力,然后随着时间的推移调整焦点。而直到Bahdanau等人发表了论文《Neural Machine Translation by Jointly Learning to Align and Translate》,该论文使用类似attention的机制在机器翻译任务上将翻译和对齐同时进行,这个工作目前是最被认可为是第一个提出attention机制应用到NLP领域中,值得一提的是,该论文2015年被ICLR录用,截至现在,谷歌引用量为5596,可见后续nlp在这一块的研究火爆程度。

注意力机制首先从人类直觉中得到,在nlp领域的机器翻译任务上首先取得不错的应用成功。简而言之,深度学习中的注意力可以广义地解释为重要性权重的向量:为了预测一个元素,例如句子中的单词,使用注意力向量来估计它与其他元素的相关程度有多强,并将其值的总和作为目标的近似值

既然注意力机制最早在nlp领域应用于机器翻译任务,那在这个之前又是怎么做的呢?传统的基于短语的翻译系统通过将源句分成多个块然后逐个词地翻译它们来完成它们的任务。 这导致了翻译输出的不流畅。想想我们人类是如何翻译的?我们首先会阅读整个待翻译的句子,然后结合上下文理解其含义,最后产生翻译。在某种程度上,神经机器翻译(NMT)的提出正是想去模仿这一过程。而在NMT的翻译模型中经典的做法是由编码器 – 解码器架构制定(encoder-decoder),用作encoder和decoder常用的是循环神经网络。这类模型大概过程是首先将源句子的输入序列送入到编码器中,提取最后隐藏的表示并用于初始化解码器的隐藏状态,然后一个接一个地生成目标单词,这个过程广义上可以理解为不断地将前一个时刻 t-1 的输出作为后一个时刻 t 的输入,循环解码,直到输出停止符为止。通过这种方式,NMT解决了传统的基于短语的方法中的局部翻译问题:它可以捕获语言中的长距离依赖性,并提供更流畅的翻译。但是这样做也存在很多缺点,譬如,RNN是健忘的,这意味着前面的信息在经过多个时间步骤传播后会被逐渐消弱乃至消失。其次,在解码期间没有进行对齐操作,因此在解码每个元素的过程中,焦点分散在整个序列中。对于前面那个问题,LSTM、GRU在一定程度能够缓解。而后者正是Bahdanau等人重视的问题。

 2、NLP中Attention mechanism的起源

在Seq2Seq结构中,encoder把所有的输入序列都编码成一个统一的语义向量context,然后再由decoder解码。而context自然也就成了限制模型性能的瓶颈。譬如机器翻译问题,当要翻译的句子较长时,一个context可能存不下那么多信息。除此之外,只用编码器的最后一个隐藏层状态,感觉上都不是很合理。实际上当我们翻译的时候譬如:Source:机器学习–>Target:machine learning。当decoder要生成”machine”的时候,应该更关注”机器”,而生成”learning”的时候,应该给予”学习”更大的权重。所以如果要改进Seq2Seq结构,一个不错的想法自然就是利用encoder所有隐藏层状态解决context限制问题。

Bahdanau等人把attention机制用到了神经网络机器翻译(NMT)上。传统的encoder-decoder模型通过encoder将Source序列编码到一个固定维度的中间语义向量context,然后在使用decoder进行解码翻译到目标语言序列。前面谈到了这种做法的局限性,而且,Bahdanau等人在摘要也说到这个context可能是提高这种基本编码器 – 解码器架构性能的瓶颈,那Bahdanau等人又是如何尝试缓解这个问题的呢?让我们来一探究竟,作者为了缓解中间向量context很难将Source序列所有必要信息压缩进来的问题,特别是对于那些很长的句子。提出在机器翻译任务上在 encoder–decoder 做出了如下扩展:将翻译和对齐联合学习。这个操作在生成Target序列的每个词时,用到的中间语义向量context是Source序列通过encoder的隐藏层的加权和,而传统的做法是只用encoder最后一个输出 ht 作为context,这样就能保证在解码不同词的时候,Source序列对现在解码词的贡献是不一样的。想想前面那个例子:”Source:机器学习–>Target:machine learning”(假如中文按照字切分)。decoder在解码”machine”时,”机”和”器”提供的权重要更大一些,同样,在解码”learning”时,”学”和”习”提供的权重相应的会更大一些,这在直觉也和人类翻译也是一致的。通过这种attention的设计,作者将Source序列的每个词(通过encoder的隐藏层输出)和Target序列(当前要翻译的词)的每个词巧妙的建立了联系。想一想,翻译每个词的时候,都有一个语义向量,而这个语义向量是Source序列每个词通过encoder之后的隐藏层的加权和。 由此可以得到一个Source序列和Target序列的对齐矩阵,通过可视化这个矩阵,可以看出在翻译一个词的时候,Source序列的每个词对当前要翻译词的重要性分布,这在直觉上也能给人一种可解释性的感觉。

3. NLP中的注意力机制

随着注意力机制的广泛应用,在某种程度上缓解了源序列和目标序列由于距离限制而难以建模依赖关系的问题。现在已经涌现出了一大批基于基本形式的注意力的不同变体来处理更复杂的任务。让我们一起来看看其在不同NLP问题中的注意力机制。

其实我们可能已经意识到了,对齐模型的设计不是唯一的,确实,在某种意义上说,根据不同的任务设计适应于特定任务的对齐模型可以看作设计出了新的attention变体,让我们再看看这个模型(函数): score(st,hi) 。再来看几个代表性的work。

Citation等人提出Content-base attention,其对齐函数模型设计为:

Bahdanau等人的Additive(*),其设计为:

Luong[4]等人文献包含了几种方式:

以及Luong等人还尝试过location-based function。这种方法的对齐分数仅从目标隐藏状态学习得到。

  • Vaswani[6]等人的Scaled Dot-Product(^)缩放点积注意:

细心的童鞋可能早就发现了这东东和点积注意力很像,只是加了个scale factor。当输入较大时,softmax函数可能具有极小的梯度,难以有效学习,所以作者加入比例因子 1/n 。

Cheng[7]等人的Self-Attention(&)可以关联相同输入序列的不同位置。 从理论上讲,Self-Attention可以采用上面的任何 score functions。在一些文章中也称为“intra-attention” 

Hu[7]对此分了个类:

前面谈到的一些Basic Attention给人的感觉能够从序列中根据权重分布提取重要元素。而Multi-dimensional Attention能够捕获不同表示空间中的term之间的多个交互,这一点简单的实现可以通过直接将多个单维表示堆叠在一起构建。Wang[8]等人提出了coupled multi-layer attentions,该模型属于多层注意力网络模型。作者称,通过这种多层方式,该模型可以进一步利用术语之间的间接关系,以获得更精确的信息。

3.1 Hierarchical(层次) Attention

再来看看Hierarchical Attention,Yang[9]等人提出了Hierarchical Attention Networks,看下面的图可能会更直观:

Hierarchical Attention Networks

这种结构能够反映文档的层次结构。模型在单词和句子级别分别设计了两个不同级别的注意力机制,这样做能够在构建文档表示时区别地对待这些内容。Hierarchical attention可以相应地构建分层注意力,自下而上(即,词级到句子级)或自上而下(词级到字符级),以提取全局和本地的重要信息。自下而上的方法上面刚谈完。那么自上而下又是如何做的呢?让我们看看Ji[10]等人的模型:

Nested Attention Hybrid Model

和机器翻译类似,作者依旧采用encoder-decoder架构,然后用word-level attention对全局语法和流畅性纠错,设计character-level attention对本地拼写错误纠正。

3.2 Self-Attention

那Self-Attention又是指什么呢?

Self-Attention(自注意力),也称为”intra-attention”(内部注意力),是关联单个序列的不同位置的注意力机制,以便计算序列的交互表示。 它已被证明在很多领域十分有效比如机器阅读,文本摘要或图像描述生成。

  • 比如Cheng[11]等人在机器阅读里面利用了自注意力。当前单词为红色,蓝色阴影的大小表示激活程度,自注意力机制使得能够学习当前单词和句子前一部分词之间的相关性。

当前单词为红色,蓝色阴影的大小表示激活程度

  • 比如Xu[12]等人利用自注意力在图像描述生成任务。注意力权重的可视化清楚地表明了模型关注的图像的哪些区域以便输出某个单词。

我们假设序列元素为 V=vi ,其匹配向量为 u 。让我们再来回顾下前面说的基本注意力的对齐函数,attention score通过 a(u,vi) 计算得到,由于是通过将外部 u 与每个元素 vi 匹配来计算注意力,所以这种形式可以看作是外部注意力。当我们把外部u替换成序列本身(或部分本身),这种形式就可以看作为内部注意力(internal attention)。

我们根据文章[7]中的例子来看看这个过程,例如句子:”Volleyball match is in progress between ladies”。句子中其它单词都依赖着”match”,理想情况下,我们希望使用自我注意力来自动捕获这种内在依赖。换句话说,自注意力可以解释为,每个单词 vi 去和V序列中的内部模式 v′ ,匹配函数 a(v′,vi) 。 v′ 很自然的选择为V中其它单词 vj ,这样遍可以计算成对注意力得分。为了完全捕捉序列中单词之间的复杂相互作用,我们可以进一步扩展它以计算序列中每对单词之间的注意力。这种方式让每个单词和序列中其它单词交互了关系。

另一方面,自注意力还可以自适应方式学习复杂的上下文单词表示。譬如经典文章:”A structured self-attentive sentence embedding”。这篇文章提出了一种通过引入自注意力机制来提取可解释句子嵌入的新模型。 使用二维矩阵而不是向量来代表嵌入,矩阵的每一行都在句子的不同部分,想深入了解的可以去看看这篇文章,另外,文章的公式感觉真的很漂亮。

值得一提还有2017年谷歌提出的Transformer[6],这是一种新颖的基于注意力的机器翻译架构,也是一个混合神经网络,具有前馈层和自注意层。论文的题目挺霸气:Attention is All you Need,毫无疑问,它是2017年最具影响力和最有趣的论文之一。那这篇文章的Transformer的庐山真面目到底是这样的呢?

这篇文章为提出许多改进,在完全抛弃了RNN的情况下进行seq2seq建模。接下来一起来详细看看吧。


Key, Value and Query:

众所周知,在NLP任务中,通常的处理方法是先分词,然后每个词转化为对应的词向量。接着一般最常见的有二类操作,第一类是接RNN(变体LSTM、GRU、SRU等),但是这一类方法没有摆脱时序这个局限,也就是说无法并行,也导致了在大数据集上的速度效率问题。第二类是接CNN,CNN方便并行,而且容易捕捉到一些全局的结构信息。很长一段时间都是以上二种的抉择以及改造,知道谷歌提供了第三类思路:纯靠注意力,也就是现在要讲的这个东东。

将输入序列编码表示视为一组键值对(K,V)以及查询 Q,因为文章取K=V=Q,所以也自然称为Self Attention。

K, V像是key-value的关系从而是一一对应的,那么上式的意思就是通过Q中每个元素query,与K中各个元素求内积然后softmax的方式,来得到Q中元素与V中元素的相似度,然后加权求和,得到一个新的向量。其中因子 n 为了使得内积不至于太大。以上公式在文中也称为点积注意力(scaled dot-product attention):输出是值的加权和,其中分配给每个值的权重由查询的点积与所有键确定

而Transformer主要由多头自注意力(Multi-Head Self-Attention)单元组成。 在NMT的上下文中,键和值都是编码器隐藏状态。 在解码器中,先前的输出被压缩成查询Q,并且通过映射该查询以及该组键和值来产生下一个输出。

3.3 Memory-based Attention

Memory-based Attention又是什么呢?我们先换种方式来看前面的注意力,假设有一系列的键值对 (ki,vi) 存在内存中和查询向量q,这样便能重写为以下过程:

这种解释是把注意力作为使用查询q的寻址过程,这个过程基于注意力分数从memory中读取内容。聪明的童鞋肯定已经发现了,如果我们假设ki=vi ,这不就是前面谈到的基础注意力么?然而,由于结合了额外的函数,可以实现可重用性和增加灵活性,所以Memory-based attention mechanism可以设计得更加强大。

那为什么又要这样做呢?在nlp的一些任务上比如问答匹配任务,答案往往与问题间接相关,因此基本的注意力技术就显得很无力了。那处理这一任务该如何做才好呢?这个时候就体现了Memory-based attention mechanism的强大了,譬如Sukhbaatar[18]等人通过迭代内存更新(也称为多跳)来模拟时间推理过程,以逐步引导注意到答案的正确位置:

在每次迭代中,使用新内容更新查询,并且使用更新的查询来检索相关内容。一种简单的更新方法为相加 qt+1=qt+ct 。那么还有其它更新方法么?当然有,直觉敏感的童鞋肯定想到了,光是这一点,就可以根据特定任务去设计,比如Kuma[13]等人的工作。这种方式的灵活度也体现在key和value可以自由的被设计,比如我们可以自由地将先验知识结合到key和value嵌入中,以允许它们分别更好地捕获相关信息。看到这里是不是觉得文章灌水其实也不是什么难事了。

3.4 Soft/Hard Attention

这个概念由《Show, Attend and Tell: Neural Image Caption Generation with Visual Attention》提出,这是对attention另一种分类。SoftAttention本质上和Bahdanau等人[3]很相似,其权重取值在0到1之间,而Hard Attention取值为0或者1。

3.5 Global/Local Attention

Luong等人[4]提出了Global Attention和Local Attention。Global Attention本质上和Bahdanau等人[3]很相似。Global方法顾名思义就是会关注源句子序列的所有词,具体地说,在计算语义向量时,会考虑编码器所有的隐藏状态。而在Local Attention中,计算语义向量时只关注每个目标词的一部分编码器隐藏状态。由于Global方法必须计算源句子序列所有隐藏状态,当句子长度过长会使得计算代价昂贵并使得翻译变得不太实际,比如在翻译段落和文档的时候。

参考文献

[1] Attention? Attention!.

[2] Neural Machine Translation (seq2seq) Tutorial.

[3] Neural Machine Translation by Jointly Learning to Align and Translate, Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. ICLR, 2015.

[4] Effective approaches to attention-based neural machine translation, Minh-Thang Luong, Hieu Pham, and Christopher D Manning. EMNLP, 2015.

[5] Neural Turing Machines, Alex Graves, Greg Wayne and Ivo Danihelka. 2014.

[6] Attention Is All You Need, Ashish Vaswani, et al. NIPS, 2017.

[7] An Introductory Survey on Attention Mechanisms in NLP Problems Dichao Hu, 2018.

[8] Coupled Multi-Layer Attentions for Co-Extraction of Aspect and Opinion Terms Wenya Wang,Sinno Jialin Pan, Daniel Dahlmeier and Xiaokui Xiao. AAAI, 2017.

[9] Hierarchical attention networks for document classification Zichao Yang et al. ACL, 2016.

[10] A Nested Attention Neural Hybrid Model for Grammatical Error Correction Jianshu Ji et al. 2017.

[11] Long Short-Term Memory-Networks for Machine Reading Jianpeng Cheng, Li Dong and Mirella Lapata. EMNLP, 2016.

[12] Show, Attend and Tell: Neural Image Caption Generation with Visual Attention Kelvin Xu et al. JMLR, 2015.

[13] Ask me anything: Dynamic memory networks for natural language processing. Zhouhan Lin al. JMLR, 2016.

[14] A structured self-attentive sentence embedding Zhouhan Lin al. ICLR, 2017.

[15] Learning Sentence Representation with Guidance of Human Attention Shaonan Wang , Jiajun Zhang, Chengqing Zong. IJCAI, 2017.

[16] Sequence to Sequence Learning with Neural Networks Ilya Sutskever et al. 2014.

[17] Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation Kyunghyun Cho, Yoshua Bengio et al. EMNLP, 2014.

[18] End-To-End Memory Networks Sainbayar Sukhbaatar et al. NIPS, 2015.

[19] 《Attention is All You Need》浅读(简介+代码)

Swin Transformer 代码详解

code:https://github.com/microsoft/Swin-Transformer

代码详解: https://zhuanlan.zhihu.com/p/367111046

预处理:

对于分类模型,输入图像尺寸为 224×224×3 ,即 H=W=224 。按照原文描述,模型先将图像分割成每块大小为 4×4 的patch,那么就会有 56×56 个patch,这就是初始resolution,也是后面每个stage会降采样的维度。后面每个stage都会降采样时长宽降到一半,特征数加倍。按照原文及原图描述,划分的每个patch具有 4×4×3=48 维特征。

  • 实际在代码中,首先使用了PatchEmbed模块(这里的PatchEmbed包括上图中的Linear Embedding 和 patch partition层),定义如下:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): # embed_dim就是上图中的C超参数
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

可以看到,实际操作使用了一个卷积层conv2d(3, 96, 4, 4),直接就做了划分patch和编码初始特征的工作,对于输入 x:B×3×224×224 ,经过一层conv2d和LayerNorm得到 x:B×562×96 。然后作为对比,可以选择性地加上每个patch的绝对位置编码,原文实验表示这种做法不好,因此不会采用(ape=false)。最后经过一层dropout,至此,预处理完成。另外,要注意的是,代码和上面流程图并不符,其实在stage 1之前,即预处理完成后,维度已经是 H/4×W/4×C ,stage 1之后已经是 H/8×W/8×2C ,不过在stage 4后不再降采样,得到的还是 H/32×W/32×8C 。

stage处理

我们先梳理整个stage的大体过程,把简单的部分先说了,再深入到复杂得的细节。每个stage,即代码中的BasicLayer,由若干个block组成,而block的数目由depth列表中的元素决定。每个block就是W-MSA(window-multihead self attention)或者SW-MSA(shift window multihead self attention),一般有偶数个block,两种SA交替出现,比如6个block,0,2,4是W-MSA,1,3,5是SW-MSA。在经历完一个stage后,会进行下采样,定义的下采样比较有意思。比如还是 56×56 个patch,四个为一组,分别取每组中的左上,右上、左下、右下堆叠一起,经过一个layernorm,linear层,实现维度下采样、特征加倍的效果。实际上它可以看成一种加权池化的过程。代码如下:

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

在经历完4个stage后,得到的是 (H/32×W/32)×8C 的特征,将其转到 8C×(H/32×W/32) 后,接一个AdaptiveAvgPool1d(1),全局平均池化,得到 8C 特征,最后接一个分类器。

PatchMerging

Block处理

SwinTransformerBlock的结构,由LayerNorm层、windowAttention层(Window MultiHead self -attention, W-MSA)、MLP层以及shiftWindowAttention层(SW-MSA)组成。

上面说到有两种block,block的代码如下:

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # 左图中最下边的LN层layerNorm层
        self.norm1 = norm_layer(dim)
        # W_MSA层或者SW-MSA层,详细的介绍看WindowAttention部分的代码
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # 左图中间部分的LN层
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # 左图最上边的MLP层
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 这里利用shift_size控制是否执行shift window操作
        # 当shift_size为0时,不执行shift操作,对应W-MSA,也就是在每个stage中,W-MSA与SW-MSA交替出现
        # 例如第一个stage中存在两个block,那么第一个shift_size=0就是W-MSA,第二个shift_size不为0
        # 就是SW-MSA
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
#slice() 函数实现切片对象,主要用在切片操作函数里的参数传递。class slice(start, stop[, step])
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
## 上述操作是为了给每个窗口给上索引

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        # 如果需要计算 SW-MSA就需要进行循环移位。
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
#shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

W-MSA

W-MSA比较简单,只要其中shift_size设置为0就是W-MSA。下面跟着代码走一遍过程。

  • 输入: x:B×562×96 , H,W=56
  • 经过一层layerNorm
  • 变形: x:B×56×56×96
  • 直接赋值给shifted_x
  • 调用window_partition函数,输入shifted_xwindow_size=7
  • 注意窗口大小以patch为单位,比如7就是7个patch,如果56的分辨率就会有8个窗口。
  • 这个函数对shifted_x做一系列变形,最终变成 82B×7×7×96
  • 返回赋值给x_windows,再变形成 82B×72×96 ,这表示所有图片,每个图片的64个window,每个window内有49个patch。
  • 调用WindowAttention层,这里以它的num_head为3为例。输入参数为x_windowsself.attn_mask,对于W-MSA,attn_mask为None,可以不用管。

WindowAttention代码如下:

代码中使用7×7的windowsize,将feature map分割为不同的window,在每个window中计算自注意力。

Self-attention的计算公式(B为相对位置编码)

绝对位置编码是在进行self-attention计算之前为每一个token添加一个可学习的参数,相对位置编码如上式所示,是在进行self-attention计算时,在计算过程中添加一个可学习的相对位置参数。

假设window_size = 2*2即每个窗口有4个token (M=2) ,如图1所示,在计算self-attention时,每个token都要与所有的token计算QK值,如图6所示,当位置1的token计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的token为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。

坐标变换
坐标变换
相对位置索引求解流程图

最后生成的是相对位置索引,relative_position_index.shape = (M2,M2) ,在网络中注册成为一个不可学习的变量,relative_position_index的作用就是根据最终的索引值找到对应的可学习的相对位置编码。relative_position_index的数值范围(0~8),即 (2M−1)∗(2M−1) ,所以相对位置编码(relative position bias table)可以由一个3*3的矩阵表示,如图7所示:这样就根据index对应位置的索引找到table对应位置的值作为相对位置编码。

图7 相对位置编码

图7中的0-8为索引值,每个索引值都对应了 M2 维可学习数据(每个token都要计算 M2 个QK值,每个QK值都要加上对应的相对位置编码)

继续以图6中 M=2 的窗口为例,当计算位置1对应的 M2 个QK值时,应用的relative_position_index = [ 4, 5, 7, 8] (M2)个 ,对应的数据就是图7中位置索引4,5,7,8位置对应的 M2 维数据,即relative_position.shape = (M2∗M2)

相对位置编码在源码WindowAttention中应用,了解原理之后就很容易能够读懂程序:

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim # 输入通道的数量
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH  初始化表

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0]) # coords_h = tensor([0,1,2,...,self.window_size[0]-1])  维度=Wh
        coords_w = torch.arange(self.window_size[1]) # coords_w = tensor([0,1,2,...,self.window_size[1]-1])  维度=Ww

        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww


        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1

        '''
        后面我们需要将其展开成一维偏移量。而对于(2,1)和(1,2)这两个坐标,在二维上是不同的,但是通过将x\y坐标相加转换为一维偏移的时候
        他们的偏移量是相等的,所以需要对其做乘法操作,进行区分
        '''

        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # 计算得到相对位置索引
        # relative_position_index.shape = (M2, M2) 意思是一共有这么多个位置
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww 

        '''
        relative_position_index注册为一个不参与网络学习的变量
        '''
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        '''
        使用从截断正态分布中提取的值填充输入张量
        self.relative_position_bias_table 是全0张量,通过trunc_normal_ 进行数值填充
        '''
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            N: number of all patches in the window
            C: 输入通过线性层转化得到的维度C
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        '''
        x.shape = (num_windows*B, N, C)
        self.qkv(x).shape = (num_windows*B, N, 3C)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).shape = (num_windows*B, N, 3, num_heads, C//num_heads)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).shape = (3, num_windows*B, num_heads, N, C//num_heads)
        '''
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        '''
        q.shape = k.shape = v.shape = (num_windows*B, num_heads, N, C//num_heads)
        N = M2 代表patches的数量
        C//num_heads代表Q,K,V的维数
        '''
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # q乘上一个放缩系数,对应公式中的sqrt(d)
        q = q * self.scale

        # attn.shape = (num_windows*B, num_heads, N, N)  N = M2 代表patches的数量
        attn = (q @ k.transpose(-2, -1))

        '''
        self.relative_position_bias_table.shape = (2*Wh-1 * 2*Ww-1, nH)
        self.relative_position_index.shape = (Wh*Ww, Wh*Ww)
        self.relative_position_index矩阵中的所有值都是从self.relative_position_bias_table中取的
        self.relative_position_index是计算出来不可学习的量
        '''
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        '''
        attn.shape = (num_windows*B, num_heads, M2, M2)  N = M2 代表patches的数量
        .unsqueeze(0):扩张维度,在0对应的位置插入维度1
        relative_position_bias.unsqueeze(0).shape = (1, num_heads, M2, M2)
        num_windows*B 通过广播机制传播,relative_position_bias.unsqueeze(0).shape = (1, nH, M2, M2) 的维度1会broadcast到数量num_windows*B
        表示所有batch通用一个索引矩阵和相对位置矩阵
        '''
        attn = attn + relative_position_bias.unsqueeze(0)

        # mask.shape = (num_windows, M2, M2)
        # attn.shape = (num_windows*B, num_heads, M2, M2)
        if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        '''
        v.shape = (num_windows*B, num_heads, M2, C//num_heads)  N=M2 代表patches的数量, C//num_heads代表输入的维度
        attn.shape = (num_windows*B, num_heads, M2, M2)
        attn@v .shape = (num_windows*B, num_heads, M2, C//num_heads)
        '''
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)   # B_:num_windows*B  N:M2  C=num_heads*C//num_heads

        #   self.proj = nn.Linear(dim, dim)  dim = C
        #   self.proj_drop = nn.Dropout(proj_drop)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # x.shape = (num_windows*B, N, C)  N:窗口中所有patches的数量

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

在上述程序中有一段mask相关程序:

if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

这个部分对应的是Swin Transformer Block 中的SW-MSA

  • 输入 x:82B×72×96 。
  • 产生 QKV ,调用线性层后,得到 82B×72×(96×3) ,拆分给不同的head,得到 82B×72×3×3×32 ,第一个3是 QKV 的3,第二个3是3个head。再permute成 3×82B×3×72×32 ,再拆解成 q,k,v ,每个都是 82B×3×72×32 。表示所有图片的每个图片64个window,每个window对应到3个不同的head,都有一套49个patch、32维的特征。
  • q 归一化
  • qk 矩阵相乘求特征内积,得到 attn:82B×3×72×72
  • 得到相对位置的编码信息relative_position_bias
    • 代码如下:
self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
  • 这里以window_size=3为例,解释以下过程:首先生成 coords:2×3×3 ,就是在一个 3×3 的窗口内,每个位置的 y,x 坐标,而relative_coords为 2×9×9 ,就是9个点中,每个点的 y 或 x 与其他所有点的差值,比如 [0][3][1] 表示3号点(第二行第一个点)与1号点(第一行第二个点)的 y 坐标的差值。然后变形,并让两个坐标分别加上 3−1=2 ,是因为这些坐标值范围 [0,2] ,因此差值的最小值为-2,加上2后从0开始。最后让 y 坐标乘上 2×3−1=5 ,应该是一个trick,调整差值范围。最后将两个维度的差值相加,得到relative_position_index, 32×32 ,为9个点之间两两之间的相对位置编码值,最后用来到self.relative_position_bias_table中寻址,注意相对位置的最大值为 (2M−2)(2M−1) ,而这个table最多有 (2M−1)(2M−1) 行,因此保证可以寻址,得到了一组给多个head使用的相对位置编码信息,这个table是可训练的参数。
  • 回到代码中,得到的relative_position_bias为 3×72×72
  • 将其加到attn上,最后一个维度softmax,dropout
  • 与 v 矩阵相乘,并转置,合并多个头的信息,得到 82B×72×96
  • 经过一层线性层,dropout,返回
  • 返回赋值给attn_windows,变形为 82B×7×7×96
  • 调用window_reverse,打回原状: B×56×56×96
  • 返回给 x ,经过FFN:先加上原来的输入 x 作为residue结构,注意这里用到timmDropPath,并且drop的概率是整个网络结构线性增长的。然后再加上两层mlp的结果。
  • 返回结果 x 。

这样,整个过程就完成了,剩下的就是SW-MSA的一些不同的操作。

  1. 首先将windows进行半个窗口的循环移位,上图中的1, 2步骤,使用torch.roll实现。
  2. 在相同的窗口中计算自注意力,计算结果如下右图所示,window0的结构保存,但是针对window2的计算,其中3与3、6与6的计算生成了attn mask 中window2中的黄色区域,针对windows2中3与6、6与3之间不应该计算自注意力(attn mask中window2的蓝色区域),将蓝色区域mask赋值为-100,经过softmax之后,起作用可以忽略不计。同理window1与window3的计算一致。
  3. 最后再进行循环移位,恢复原来的位置。

原论文图中的Stage和程序中的一个Stage不同:

程序中的BasicLayer为一个Stage,在BasicLayer中调用了上面讲到的SwinTransformerBlock和PatchMerging模块:

class BasicLayer(nn.Module):  # 论文图中每个stage里对应的若干个SwinTransformerBlock
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth # swin_transformer blocks的个数
        self.use_checkpoint = use_checkpoint

        # build blocks  从0开始的偶数位置的SwinTransformerBlock计算的是W-MSA,奇数位置的Block计算的是SW-MSA,且shift_size = window_size//2
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)  # blk = SwinTransformerBlock
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

Part 3 : 不同视觉任务输出

程序中对应的是图片分类任务,经过Part 2 之后的数据通过 norm/avgpool/flatten:

 x = self.norm(x)  # B L C
 x = self.avgpool(x.transpose(1, 2))  # B C 1
 x = torch.flatten(x, 1) # B C

之后通过nn.Linear将特征转化为对应的类别:

self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

应用于其他不同的视觉任务时,只需要将输出进行特定的修改即可。

完整的SwinTransformer程序如下:

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes # 1000
        self.num_layers = len(depths) # [2, 2, 6, 2]  Swin_T 的配置
        self.embed_dim = embed_dim # 96
        self.ape = ape # False
        self.patch_norm = patch_norm # True
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))  # 96*2^3
        self.mlp_ratio = mlp_ratio # 4

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features) # norm_layer = nn.LayerNorm
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)  # 使用self.apply 初始化参数

    def _init_weights(self, m):
        # is_instance 判断对象是否为已知类型
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)  # x.shape = (H//4, W//4, C)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)  # self.pos_drop = nn.Dropout(p=drop_rate)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1) # B C
        return x

    def forward(self, x):
        x = self.forward_features(x)  # x是论文图中Figure 3 a图中最后的输出
        #  self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        x = self.head(x) # x.shape = (B, num_classes)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops

补充:有关swin transformer相对位置编码:

VIT

Dosovitskiy et al. An image is worth 16×16 words: transformers for image recognition at scale. In ICLR, 2021

step1 :分割图片

step2 向量化:从九个快变成九个向量

step3:向量线性变换:(linear embedding线性嵌入层)

step4:将位置编码添加到z上:

step4:添加一个cls向量:

step5:只利用cls的输出

按照上面的流程图,一个ViT block可以分为以下几个步骤

(1) patch embedding:例如输入图片大小为224×224,将图片分为固定大小的patch,patch大小为16×16,则每张图像会生成224×224/16×16=196个patch,即输入序列长度为196,每个patch维度16x16x3=768,线性投射层的维度为768xN (N=768),因此输入通过线性投射层之后的维度依然为196×768,即一共有196个token,每个token的维度是768。这里还需要加上一个特殊字符cls,因此最终的维度是197×768。到目前为止,已经通过patch embedding将一个视觉问题转化为了一个seq2seq问题

(2) positional encoding(standard learnable 1D position embeddings):ViT同样需要加入位置编码,位置编码可以理解为一张表,表一共有N行,N的大小和输入序列长度相同,每一行代表一个向量,向量的维度和输入序列embedding的维度相同(768)。注意位置编码的操作是sum,而不是concat。加入位置编码信息之后,维度依然是197×768

(3) LN/multi-head attention/LN:LN输出维度依然是197×768。多头自注意力时,先将输入映射到q,k,v,如果只有一个头,qkv的维度都是197×768,如果有12个头(768/12=64),则qkv的维度是197×64,一共有12组qkv,最后再将12组qkv的输出拼接起来,输出维度是197×768,然后在过一层LN,维度依然是197×768

(4) MLP:将维度放大再缩小回去,197×768放大为197×3072,再缩小变为197×768

一个block之后维度依然和输入相同,都是197×768,因此可以堆叠多个block。最后会将特殊字符cls对应的输出 zL0 作为encoder的最终输出 ,代表最终的image presentation(另一种做法是不加cls字符,对所有的tokens的输出做一个平均),如下图公式(4),后面接一个MLP进行图片分类

vit需要预训练+微调

• Pretrain the model on Dataset A, fine-tune the model on Dataset B,
and evaluate the model on Dataset B.
• Pretrained on ImageNet (small), ViT is slightly worse than ResNet.
• Pretrained on ImageNet-21K (medium), ViT is comparable to ResNet.
• Pretrained on JFT (large), ViT is slightly better than ResNet.

效果:

Swin Transformer论文解读与思考

论文: https://arxiv.org/abs/2103.14030

github:https://github.com/microsoft/Swin-Transformer

代码详解 https://zhuanlan.zhihu.com/p/384514268

Vision Transformer , Vision MLP 超详细解读 (原理分析+代码解读) (目录)

论文详解:https://space.bilibili.com/1567748478/channel/collectiondetail?sid=32744

Swin Transformer视频讲解:

https://github.com/WZMIAOMIAO/deep-learning-for-image-processing

摘要 :目前transformer应用于CV领域的挑战主要有两个,一个是图片多尺度语义信息的问题,同一个物体在不同图片中的大小尺度变化很大,另外就是难以处理高分辨的图片,如果以pix像素作为序列元素,那么计算成本太大,因此一部分方法是将CNN提取图片特征在送进transformer中 ,或者通过patch,将图片变成一个个的patch。作者提出Swin Transformer 目标是希望作为一种计算机视觉的通用主干网络(因为VIT的提出已经证明了Transformer在CV的可行性),这是一种层级的架构。通过窗口注意力以及转移窗口注意力,不仅降低了计算量,同时层级架构对于不同尺度的信息处理都十分灵活,该架构在图像分类、目标检测、语义分割等任务中表现出色。(对于图像分类、目标检测、语义分割等下游任务,尤其是密集预测任务,多尺度特征是十分必要的)

引言

首先来看作者给出的Swin Transformer 和 VIT结构对比:

VIT的patch固定16*16(可以认为是16倍下采样),多尺度特征处理不好,因为整个过程都是在同一尺度下操作的,出来的特征是单尺度的,优点是全局的特征处理比较强,因为是在全局的尺度进行操作的,但因此他的复杂度跟图像尺寸成平方倍的增长,很难处理目前图像分割检测。再来看 Swin Transformer ,通过图可以看出作者借鉴了CNN的很多设计思路,为了减少序列长度,减低计算量,仅在上面的红框中进行自注意力计算,计算复杂度会跟整张图片的大小成线性关系。另外作者使用基于窗口的注意力的也可以很好的把握物体的全局信息(因为在CV中,一个物体的绝大部分都存在单个windows窗口中,很少会横跨多个窗口),另外CNN网络的如何抓住物体的多尺度特征?是因为pool池化层的存在,每次池化能够增大卷积核看到的感受野。因此作者提出了patch merging,将相邻的四个patch合并成一个大patch(可以认为是加权池化),这样合并出来的一个 大patch就可以看到四个小patch内容感受野增大。有了多尺度特征(4*,8*,16*多尺度特征图)以后,可以接一个FPN头,由于做检测任务,也可以放在unet做分割任务,这就是作者所说的, Swin Transformer 是可以做一个通用骨干网络。

    Transformer的初衷就是更好的理解上下文,如果窗口都是不重叠的,那自注意力真的就变成孤立自注意力,就没有全局建模的能力   
    Swin Transformer 的一个关键设计因素:移动窗口操作。在第l层,通过划分不同的小窗口(实际中是一个窗口有7*7个patch(最小单位),这里示意图以4*4的patch作为一个窗口),自注意力只在窗口中计算 ,就可以有效降低序列长度,从而减少计算复杂度。shift操作可以认为是将l层的窗口整体向右下加移动两个patch所形成的新的窗口,新的特征图进行分割windws以后就有l+1层所示的这些窗口(如下图共九个)了。如果没有shift,那么所有窗口不重叠,在窗口进行自注意力时候,窗口之间无法交互,就无法达到transformer的初衷了(更好的理解上下文),shift后不同窗口的patch就可以进行交互了。再加上一个patch merging操作,不断扩大感受野,到最后几层的时候,每个patch的感受野已经很大了,实际上就可以看到大部分图片了,shift操作以后,就可以看成是全局注意力操作,这样即省内存效果也好。

引言的最后,作者坚信,一个CV和NLP大一统的框架是可以促进两个领域共同发展的,但实际上 Swin Transformer 更多的是利用了CNN的先验知识,从而在计算机视觉领域大杀四方。但是在模型大一统上,也就是 unified architecture 上来说,其实 ViT 还是做的更好的,因为它真的可以什么都不改,什么先验信息都不加,就能让Transformer在两个领域都能用的很好,这样模型不仅可以共享参数,而且甚至可以把所有模态的输入直接就拼接起来,当成一个很长的输入,直接扔给Transformer去做,而不用考虑每个模态的特性

先看结论:

这篇论文提出了 Swin Transformer,它是一个层级式的Transformer,而且它的计算复杂度是跟输入图像的大小呈线性增长的。Swin Transformerr 在 COCO 和 ADE20K上的效果都非常的好,远远超越了之前最好的方法,所以作者说基于此,希望 Swin Transformer 能够激发出更多更好的工作,尤其是在多模态方面。

因为在Swin Transformer 这篇论文里最关键的一个贡献就是基于 Shifted Window 的自注意力,它对很多视觉的任务,尤其是对下游密集预测型的任务是非常有帮助的,但是如果 Shifted Window 操作不能用到 NLP 领域里,其实在模型大一统上论据就不是那么强了,所以作者说接下来他们的未来工作就是要把 Shifted Windows用到 NLP 里面,而且如果真的能做到这一点,那 Swin Transformer真的就是一个里程碑式的工作了,而且模型大一统的故事也就讲的圆满了

方法

主要分为两大块

  • 大概把整体的流程讲了一下,主要就是过了一下前向过程,以及提出的 patch merging 操作是怎么做的
  • 基于 Shifted Window 的自注意力,Swin Transformer怎么把它变成一个transformer block 进行计算

前向过程

  • 假设说有一张224*224*3(ImageNet 标准尺寸)的输入图片
  • 第一步就是像 ViT 那样把图片打成 patch,在 Swin Transformer 这篇论文里,它的 patch size 是4*4,而不是像 ViT 一样16*16,所以说它经过 patch partition 打成 patch 之后,得到图片的尺寸是56*56*48,56就是224/4,因为 patch size 是4,向量的维度48,因为4*4*3,3 是图片的 RGB 通道
  • 打完了 patch ,接下来就要做 Linear Embedding,也就是说要把向量的维度变成一个预先设置好的值,就是 Transformer 能够接受的值,在 Swin Transformer 的论文里把这个超参数设为 c,对于 Swin tiny 网络来说,也就是上图中画的网络总览图,它的 c 是96,所以经历完 Linear Embedding 之后,输入的尺寸就变成了56*56*96,前面的56*56就会拉直变成3136,变成了序列长度,后面的96就变成了每一个token向量的维度,其实 Patch Partition 和 Linear Embedding 就相当于是 ViT 里的Patch Projection 操作,而在代码里也是用一次卷积操作就完成了,
  • 第一部分跟 ViT 其实还是没有区别的,但紧接着区别就来了
  • 首先序列长度是3136,对于 ViT 来说,用 patch size 16*16,它的序列长度就只有196,是相对短很多的,这里的3136就太长了,是目前来说Transformer不能接受的序列长度,所以 Swin Transformer 就引入了基于窗口的自注意力计算,每个窗口按照默认来说,都只有七七四十九个 patch,所以说序列长度就只有49就相当小了,这样就解决了计算复杂度的问题
  • 所以也就是说, stage1中的swin transformer block 是基于窗口计算自注意力的,现在暂时先把 transformer block当成是一个黑盒,只关注输入和输出的维度,对于 Transformer 来说,如果不对它做更多约束的话,Transformer输入的序列长度是多少,输出的序列长度也是多少,它的输入输出的尺寸是不变的,所以说在 stage1 中经过两层Swin Transformer block 之后,输出还是56*56*96
  • 到这其实 Swin Transformer的第一个阶段就走完了,也就是先过一个 Patch Projection 层,然后再过一些 Swin Transformer block,接下来如果想要有多尺寸的特征信息,就要构建一个层级式的 transformer,也就是说需要一个像卷积神经网络里一样,有一个类似于池化的操作

Patch Merging

Patch Merging 其实在之前一些工作里也有用到,它很像 Pixel Shuffle 的上采样的一个反过程,Pixel Shuffle 是 lower level 任务中很常用的一个上采样方式

  • 假如有一个张量, Patch Merging 顾名思义就是把临近的小 patch 合并成一个大 patch,这样就可以起到下采样一个特征图的效果了
  • 这里因为是想下采样两倍,所以说在选点的时候是每隔一个点选一个,也就意味着说对于这个张量来说,每次选的点是1、1、1、1
  • 其实在这里的1、2、3、4并不是矩阵里有的值,而是给它的一个序号,同样序号位置上的 patch 就会被 merge 到一起,这个序号只是为了帮助理解
  • 经过隔一个点采一个样之后,原来的这个张量就变成了四个张量,也就是说所有的1都在一起了,2在一起,3在一起,4在一起,如果原张量的维度是 h * w * c ,当然这里 c 没有画出来,经过这次采样之后就得到了4个张量,每个张量的大小是 h/2、w/2,它的尺寸都缩小了一倍
  • 现在把这四个张量在 c 的维度上拼接起来,也就变成了下图中所画出来的形式,张量的大小就变成了 h/2 * w/2 * 4c,相当于用空间上的维度换了更多的通道数
  • 通过这个操作,就把原来一个大的张量变小了,就像卷积神经网络里的池化操作一样,为了跟卷积神经网络那边保持一致(不论是 VGGNet 还是 ResNet,一般在池化操作降维之后,通道数都会翻倍,从128变成256,从256再变成512),所以这里也只想让他翻倍,而不是变成4倍,所以紧接着又再做了一次操作,就是在 c 的维度上用一个1乘1的卷积,把通道数降下来变成2c,通过这个操作就能把原来一个大小为 h*w*c 的张量变成 h/2 * w/2 *2c 的一个张量,也就是说空间大小减半,但是通道数乘2,这样就跟卷积神经网络完全对等起来了

这里其实会发现,特征图的维度真的跟卷积神经网络好像,因为如果回想残差网络的多尺寸的特征,就是经过每个残差阶段之后的特征图大小也是56*56、28*28、14*14,最后是7*7

而且为了和卷积神经网络保持一致,Swin Transformer这篇论文并没有像 ViT 一样使用 CLS token,ViT 是给刚开始的输入序列又加了一个 CLS token,所以这个长度就从196变成了197,最后拿 CLS token 的特征直接去做分类,但 Swin Transformer 没有用这个 token,它是像卷积神经网络一样,在得到最后的特征图之后用global average polling,就是全局池化的操作,直接把7*7就取平均拉直变成1了

作者这个图里并没有画,因为 Swin Transformer的本意并不是只做分类,它还会去做检测和分割,所以说它只画了骨干网络的部分,没有去画最后的分类头或者检测头,但是如果是做分类的话,最后就变成了1*768,然后又变成了1*1,000

所以看完整个前向过程之后,就会发现 Swin Transformer 有四个 stage,还有类似于池化的 patch merging 操作,自注意力还是在小窗口之内做的以及最后还用的是 global average polling,所以说 Swin Transformer 这篇论文真的是把卷积神经网络和 Transformer 这两系列的工作完美的结合到了一起,也可以说它是披着Transformer皮的卷积神经网络

主要贡献

这篇论文的主要贡献就是基于窗口或者移动窗口的自注意力,这里作者又写了一段研究动机,就是为什么要引入窗口的自注意力,其实跟之前引言里说的都是一个事情,就是说全局自注意力的计算会导致平方倍的复杂度,同样当去做视觉里的下游任务,尤其是密集预测型的任务,或者说遇到非常大尺寸的图片时候,这种全局算自注意力的计算复杂度就非常贵了,所以就用窗口的方式去做自注意力

重点:窗口注意力

原图片会被平均的分成一些没有重叠的窗口,拿第一层之前的输入来举例,它的尺寸就是56*56*96,也就说有一个维度是56*56张量,然后把它切成一些不重叠的方格(论文中使用7*7的patch作为一个window窗口)

  • 现在所有自注意力的计算都是在这些小窗口里完成的,就是说序列长度永远都是7*7=49
  • 原来大的整体特征图到底里面会有多少个窗口呢?其实也就是每条边56/7就8个窗口,也就是说一共会有8*8等于64个窗口,就是说会在这64个窗口里分别去算它们的自注意力

基于窗口的自注意力模式的计算复杂度计算:

  • 如果现在有一个输入,自注意力首先把它变成 q k v 三个向量,这个过程其实就是原来的向量分别乘了三个系数矩阵
  • 一旦得到 query 和 k 之后,它们就会相乘,最后得到 attention,也就是自注意力的矩阵
  • 有了自注意力之后,就会和 value 做一次乘法,也就相当于是做了一次加权
  • 最后因为是多头自注意力,所以最后还会有一个 projection layer,这个投射层会把向量的维度投射到我们想要的维度

如果这些向量都加上它们该有的维度,也就是说刚开始输入是 h*w*c

  • 公式(1)对应的是标准的多头自注意力的计算复杂度
  • 每一个图片大概会有 h*w 个 patch,在刚才的例子里,h 和 w 分别都是56,c 是特征的维度
  • 公式(2)对应的是基于窗口的自注意力计算的复杂度,这里的 M 就是刚才的7,也就是说一个窗口的某条边上有多少个patch

基于窗口的自注意力计算复杂度又是如何得到的呢?

  • 因为在每个窗口里算的还是多头自注意力,所以可以直接套用公式(1),只不过高度和宽度变化了,现在高度和宽度不再是 h * w,而是变成窗口有多大了,也就是 M*M,也就是说现在 h 变成了 M,w 也是 M,它的序列长度只有 M * M 这么大
  • 所以当把 M 值带入到公式(1)之后,就得到计算复杂度是4 * M^2 * c^2 + 2 * M^4 * c,这个就是在一个窗口里算多头自注意力所需要的计算复杂度
  • 那我们现在一共有 h/M * w/M 个窗口,现在用这么多个窗口乘以每个窗口所需要的计算复杂度就能得到公式(2)了

对比公式(1)和公式(2),虽然这两个公式前面这两项是一样的,只有后面从 (h*w)^2变成了 M^2 * h * w,看起来好像差别不大,但其实如果仔细带入数字进去计算就会发现,计算复杂的差距是相当巨大的,因为这里的 h*w 如果是56*56的话, M^2 其实只有49,所以是相差了几十甚至上百倍的

这种基于窗口计算自注意力的方式虽然很好地解决了内存和计算量的问题,但是窗口和窗口之间没有通信,这样就达不到全局建模了,也就文章里说的会限制模型的能力,所以最好还是要有一种方式能让窗口和窗口之间互相通信起来,这样效果应该会更好,因为具有上下文的信息,所以作者就提出移动窗口的方式

移动窗口:

移动窗口就是把原来的窗口往右下角移动一半窗口的距离,如果Transformer是上下两层连着做这种操作,先是 window再是 shifted window 的话,就能起到窗口和窗口之间互相通信的目的了

所以说在 Swin Transformer里, transformer block 的安排是有讲究的,每次都是先要做一次基于窗口的多头自注意力,然后再做一次基于移动窗口的多头自注意力,这样就达到了窗口和窗口之间的互相通信。如下图所示

  • 每次输入先进来之后先做一次 Layernorm,然后做窗口的多头自注意力,然后再过 Layernorm 过 MLP,第一个 block 就结束了
  • 这个 block 结束以后,紧接着做一次Shifted window,也就是基于移动窗口的多头自注意力,然后再过 MLP 得到输出
  • 这两个 block 加起来其实才算是 Swin Transformer 一个基本的计算单元,这也就是为什么stage1、2、3、4中的 swin transformer block 为什么是 *2、*2、*6、*2,也就是一共有多少层 Swin Transformer block 的数字总是偶数,因为它始终都需要两层 block连在一起作为一个基本单元,所以一定是2的倍数

到此,Swin Transformer整体的故事和结构就已经讲完了,主要的研究动机就是想要有一个层级式的 Transformer,为了这个层级式,所以介绍了 Patch Merging 的操作,从而能像卷积神经网络一样把 Transformer 分成几个阶段,为了减少计算复杂度,争取能做视觉里密集预测的任务,所以又提出了基于窗口和移动窗口的自注意力方式,也就是连在一起的两个Transformer block,最后把这些部分加在一起,就是 Swin Transformer 的结构

提高移动窗口的计算效率:

  • 一个是怎样提高移动窗口的计算效率,他们采取了一种非常巧妙的 masking(掩码)的方式
  • 另外一个点就是这篇论文里没有用绝对的位置编码,而是用相对的位置编码

masking(掩码)的方式计算移动窗口自注意力:为什么需要使用?

为了提高计算效率,因为如果直接计算右下图的九个窗口的自注意力,不同大小的窗口无法合并成一个batch进行计算。

  • 上图是一个基础版本的移动窗口,就是把左边的窗口模式变成了右边的窗口方式
  • 虽然这种方式已经能够达到窗口和窗口之间的互相通信了,但是会发现一个问题,就是原来计算的时候,特征图上只有四个窗口,但是做完移动窗口操作之后得到了9个窗口,窗口的数量增加了,而且每个窗口里的元素大小不一,比如说中间的窗口还是4*4,有16个 patch,但是别的窗口有的有4个 patch,有的有8个 patch,都不一样了,如果想做快速运算,就是把这些窗口全都压成一个 patch直接去算自注意力,就做不到了,因为窗口的大小不一样
  • 有一个简单粗暴的解决方式就是把这些小窗口周围再 pad 上0 ,把它照样pad成和中间窗口一样大的窗口,这样就有9个完全一样大的窗口,这样就还能把它们压成一个batch,就会快很多
  • 但是这样的话,无形之中计算复杂度就提升了,因为原来如果算基于窗口的自注意力只用算4个窗口,但是现在需要去算9个窗口,复杂度一下提升了两倍多,所以还是相当可观的
  • 那怎么能让第二次移位完的窗口数量还是保持4个,而且每个窗口里的patch数量也还保持一致呢?作者提出了一个非常巧妙的掩码方式,如下图所示

上图是说,当通过普通的移动窗口方式,得到9个窗口之后,现在不在这9个窗口上算自注意力,先再做一次循环移位( cyclic shift )

  • 经过这次循环移位之后,原来的窗口(虚线)就变成了现在窗口(实线)的样子,那如果在大的特征图上再把它分成四宫格的话,我在就又得到了四个窗口,意思就是说移位之前的窗口数也是4个,移完位之后再做一次循环移位得到窗口数还是4个,这样窗口的数量就固定了,也就说计算复杂度就固定了
  • 但是新的问题就来了,虽然对于移位后左上角的窗口(也就是移位前最中间的窗口)来说,里面的元素都是互相紧挨着的,他们之间可以互相两两做自注意力,但是对于剩下几个窗口来说,它们里面的元素是从别的很远的地方搬过来的,所以他们之间,按道理来说是不应该去做自注意力,也就是说他们之间不应该有什么太大的联系
  • 解决这个问题就需要一个很常规的操作,也就是掩码操作,这在Transformer过去的工作里是层出不穷,很多工作里都有各式各样的掩码操作
  • 在 Swin Transformer这篇论文里,作者也巧妙的设计了几种掩码的方式,从而能让一个窗口之中不同的区域之间也能用一次前向过程,就能把自注意力算出来,但是互相之间都不干扰,也就是后面的 masked Multi-head Self Attention(MSA)
  • 算完了多头自注意力之后,还有最后一步就是需要把循环位移再还原回去,也就是说需要把A、B、C再还原到原来的位置上去,原因是还需要保持原来图片的相对位置大概是不变的,整体图片的语义信息也是不变的,如果不把循环位移还原的话,那相当于在做Transformer的操作之中,一直在把图片往右下角移,不停的往右下角移,这样图片的语义信息很有可能就被破坏掉了
  • 所以说整体而言,上图介绍了一种高效的、批次的计算方式比如说本来移动窗口之后得到了9个窗口,而且窗口之间的patch数量每个都不一样,为了达到高效性,为了能够进行批次处理,先进行一次循环位移,把9个窗口变成4个窗口,然后用巧妙的掩码方式让每个窗口之间能够合理地计算自注意力,最后再把算好的自注意力还原,就完成了基于移动窗口的自注意力计算

掩码操作如何实现 :

作者通过这种巧妙的循环位移的方式和巧妙设计的掩码模板,从而实现了只需要一次前向过程,就能把所有需要的自注意力值都算出来,而且只需要计算4个窗口,也就是说窗口的数量没有增加,计算复杂度也没有增加,非常高效的完成了这个任务

作者给出了不同窗口的不同掩码矩阵:

上图示例的Cyclic Shifting方法,可以保持面向计算的window数量保持不变(还是2X2),在window内部通过attention mask来计算子window中的自注意力。

Swin Transformer的几个变体

  • Swin Tiny
  • Swin Small
  • Swin Base
  • Swin Large

Swin Tiny的计算复杂度跟 ResNet-50 差不多,Swin Small 的复杂度跟 ResNet-101 是差不多的,这样主要是想去做一个比较公平的对比

这些变体之间有哪些不一样呢?,其实主要不一样的就是两个超参数

  • 一个是向量维度的大小 c
  • 另一个是每个 stage 里到底有多少个 transform block

这里其实就跟残差网络就非常像了,残差网络也是分成了四个 stage,每个 stage 有不同数量的残差块

实验

分类

首先是分类上的实验,这里一共说了两种预训练的方式

  • 第一种就是在正规的ImageNet-1K(128万张图片、1000个类)上做预训练
  • 第二种方式是在更大的ImageNet-22K(1,400万张图片、2万多个类别)上做预训练

当然不论是用ImageNet-1K去做预训练,还是用ImageNet-22K去做预训练,最后测试的结果都是在ImageNet-1K的测试集上去做的,结果如下表所示

  • 上半部分是ImageNet-1K预训练的模型结果
  • 下半部分是先用ImageNet-22K去预训练,然后又在ImageNet-1K上做微调,最后得到的结果
  • 在表格的上半部分,作者先是跟之前最好的卷积神经网络做了一下对比,RegNet 是之前 facebook 用 NASA 搜出来的模型,EfficientNet 是 google 用NASA 搜出来的模型,这两个都算之前表现非常好的模型了,他们的性能最高会到 84.3
  • 接下来作者就写了一下之前的 Vision Transformer 会达到什么效果,对于 ViT 来说,因为它没有用很好的数据增强,而且缺少偏置归纳,所以说它的结果是比较差的,只有70多
  • 换上 DeiT 之后,因为用了更好的数据增强和模型蒸馏,所以说 DeiT Base 模型也能取得相当不错的结果,能到83.1
  • 当然 Swin Transformer 能更高一些,Swin Base 最高能到84.5,稍微比之前最好的卷积神经网络高那么一点点,就比84.3高了0.2
  • 虽然之前表现最好的 EfficientNet 的模型是在 600*600 的图片上做的,而 Swin Base 是在 384*384 的图片上做的,所以说 EfficientNet 有一些优势,但是从模型的参数和计算的 FLOPs 上来说 EfficientNet 只有66M,而且只用了 37G 的 FLOPs,但是 Swin Transformer 用了 88M 的模型参数,而且用了 47G 的 FLOPs,所以总体而言是伯仲之间
  • 表格的下半部分是用 ImageNet-22k 去做预训练,然后再在ImageNet-1k上微调最后得到的结果
  • 这里可以看到,一旦使用了更大规模的数据集,原始标准的 ViT 的性能也就已经上来了,对于 ViT large 来说它已经能得到 85.2 的准确度了,已经相当高了
  • 但是 Swin Large 更高,Swin Large 最后能到87.3,这个是在不使用JFT-300M,就是特别大规模数据集上得到的结果,所以还是相当高的

目标检测

  • 表2(a)中测试了在不同的算法框架下,Swin Transformer 到底比卷积神经网络要好多少,主要是想证明 Swin Transformer 是可以当做一个通用的骨干网络来使用的,所以用了 Mask R-CNN、ATSS、RepPointsV2 和SparseR-CNN,这些都是表现非常好的一些算法,在这些算法里,过去的骨干网络选用的都是 ResNet-50,现在替换成了 Swin Tiny
  • Swin Tiny 的参数量和 FLOPs 跟 ResNet-50 是比较一致的,从后面的对比里也可以看出来,所以他们之间的比较是相对比较公平的
  • 可以看到,Swin Tiny 对 ResNet-50 是全方位的碾压,在四个算法上都超过了它,而且超过的幅度也是比较大的
  • 接下来作者又换了一个方式做测试,现在是选定一个算法,选定了Cascade Mask R-CNN 这个算法,然后换更多的不同的骨干网络,比如 DeiT-S、ResNet-50 和 ResNet-101,也分了几组,结果如上图中表2(b)所示
  • 可以看出,在相似的模型参数和相似的 Flops 之下,Swin Transformer 都是比之前的骨干网络要表现好的
  • 接下来作者又做了第三种测试的方式,如上图中的表2(c)所示,就是系统层面的比较,这个层面的比较就比较狂野了,就是现在追求的不是公平比较,什么方法都可以上,可以使用更多的数据,可以使用更多的数据增强,甚至可以在测试的使用 test time augmentation(TTA)的方式
  • 可以看到,之前最好的方法 Copy-paste 在 COCO Validation Set上的结果是55.9,在 Test Set 上的结果是56,而这里如果跟最大的 Swin Transformer–Swin Large 比,它的结果分别能达到58和58.7,这都比之前高了两到三个点

语义分割

  • 上图表3里可以看到之前的方法,一直到 DeepLab V3、ResNet 其实都用的是卷积神经网络,之前的这些方法其实都在44、45左右徘徊
  • 但是紧接着 Vision Transformer 就来了,那首先就是 SETR 这篇论文,他们用了 ViT Large,所以就取得了50.3的这个结果
  • Swin Transformer Large也取得了53.5的结果,就刷的更高了
  • 其实作者这里也有标注,就是有两个“+”号的,意思是说这些模型是在ImageNet-22K 数据集上做预训练,所以结果才这么好

消融实验

实验结果如下图所示

  • 上图中表4主要就是想说一下移动窗口以及相对位置编码到底对 Swin Transformer 有多有用
  • 可以看到,如果光分类任务的话,其实不论是移动窗口,还是相对位置编码,它的提升相对于基线来说,也没有特别明显,当然在ImageNet的这个数据集上提升一个点也算是很显着了
  • 但是他们更大的帮助,主要是出现在下游任务里,就是 COCO 和 ADE20K 这两个数据集上,也就是目标检测和语义分割这两个任务上
  • 可以看到,用了移动窗口和相对位置编码以后,都会比之前大概高了3个点左右,提升是非常显着的,这也是合理的,因为如果现在去做这种密集型预测任务的话,就需要特征对位置信息更敏感,而且更需要周围的上下文关系,所以说通过移动窗口提供的窗口和窗口之间的互相通信,以及在每个 Transformer block都做更准确的相对位置编码,肯定是会对这类型的下游任务大有帮助的

总结

虽然前面已经说了很多 Swin Transformer 的影响力啊已经这么巨大了,但其实他的影响力远远不止于此,论文里这种对卷积神经网络,对 Transformer,还有对 MLP 这几种架构深入的理解和分析是可以给更多的研究者带来思考的,从而不仅可以在视觉领域里激发出更好的工作,而且在多模态领域里,相信它也能激发出更多更好的工作

BPR:用于实例分割的边界Patch优化(CVPR2021)

 

Look Closer to Segment Better: Boundary Patch Refinement for Instance Segmentation

代码链接:https://github.com/tinyalpha/BPR

后处理分割结果,效果是即插即用后处理模块当年的sota通过将 BPR 框架应用于 PolyTransform + SegFix 基线,我们在 Cityscapes 排行榜上排名第一。

从目前的排名来说(22.09.23),排名第五,与top1相差不到2个百分点,而 BPR后处理使得PolyTransform + SegFix的效果提升了1.5个百分点。 相比于MASK-RCNN提升了4.2个百分点。

CVPR21上一篇关于实例分割的文章。对于Mask RCNN来说,其最终得到的mask分辨率太低,因此还原到原尺寸的时候,一些boundary信息就显得非常粗糙,导致预测生成的mask效果不尽如人意。而且处于boundary的pixel本身数量相比于整张image来说很少,同时本身难以做分类。现有的一些方法试图提升boundary quality,但预测mask边界这个task本身的复杂度和segmentation很接近了,因此开销较大。

因此本文作者提出了一种crop-and-refine的策略。首先通过经典的实例分割网络(如Mask RCNN)得到coarse mask。随后在mask的boundary出提取出一系列的patch,随后将这些patch送入一个Refinement Network,这个Refinement Network负责做二分类的语义分割,进而对boundary处的patch进行优化,整个后处理的优化网络称为BPR(Boundary Patch Refinement)。该网络可以解决传统Mask RCNN预测的mask的边界粗糙的问题。

本文的核心就是在Mask RCNN一类的网络给出coarse mask后,如何设计Refine Network来对这个粗糙 mask 的边界进行优化,进而得到resolution更高,boundary quality更好的mask。

给定一个coarse mask(上图a),首先需要决定这个mask的哪些部分要做refine。这里作者提出了一种sliding-window式的方法提取到boundary处的一系列patch(上图b)。具体来说,就是在mask边界处密集assign正方形的bounding box,这些box内部囊括了boundary pixel。随后,由于这些box有的overlap太大导致redundant(冗余),这里采用NMS进行过滤(上图c),以实现速度和精度的trade-off(平衡)。

随后这些survive下来的image patch(上图d)和mask patch(上图e)都resize到同一尺寸,一起喂入Refinement Network。这里作者argue说一定要喂入mask patch,因为一旦拥有mask patch的location和semantic信息,这个refinement network就不再需要学习instance-level semantic(实例类别信息,比如这个image patch属于哪个类别)了。所以,refinement network只需要学习boundary处的hard pixel,并把它们正确分类。

关于Refinement Network,其任务是为每一个提取出来的boundary patch独立地做二分类语义分割,任何的语义分割模型都可以搬过来做这个task。输入的通道数为4(RGB+mask),输出通道数为2(BG or FG),这里作者采用了HRNetV2(CVPR 2019),这种各种level feature不断做融合的网络可以maintain高分辨率的representation。通过合理的增加input size,boundary batch就可以得到比之前方法更高的resolution。

HRNetV2 网络结构

在对每个patch独立地refine以后,需要将它们reassemble(组装)到coarse mask上面。有的相邻的patch可能存在overlap的情况,最终的结果是取平均,以0.5作为阈值判断某个pixel属于前景或是背景。

Experiment

这里的指标是AP (Average precision):指的是PR曲线的面积(AP就是平均精准度,简单来说就是对PR曲线上的Precision值求均值。)对于实例分割的评价指标:使用AP评价指标

实例分割和目标检测mAP计算时除了IOU计算方式(实例分割是mask间的IOU)不同,其他都是一样的.

对于一个二分类任务,二分类器的预测结果可分为以下4类:

二分类器的结果可分为4类

Precision的定义为:

Recall的定义为: 

Precision从预测结果角度出发,描述了二分类器预测出来的正例结果中有多少是真实正例,即该二分类器预测的正例有多少是准确的;Recall从真实结果角度出发,描述了测试集中的真实正例有多少被二分类器挑选了出来,即真实的正例有多少被该二分类器召回。

逐步降低二分类器预测正例的门槛,则每次可以计算得到当前的Precision和Recall。以Recall作为横轴,Precision作为纵轴可以得到Precision-Recall曲线图,简称为P-R图。

详细解释:目标检测/实例分割中 AP 和 mAP 的混淆指标

preview

首先通过实验证明了将mask patch一并作为输入的重要性:

patch size、不同的patch extraction策略,input size对结果的影响:

RefineNet的选取,NMS的阈值:

Cityscape上与其他方法的比较:PolyTransform + SegFix baseline,达到最高的AP。

迁移到其他model上面的结果 and coco数据集上的结果

Mask-RCNN论文

论文:http://cn.arxiv.org/pdf/1703.06870v3

代码:https://github.com/facebookresearch/maskrcnn-benchmark

B站网络详解 FPN

Introduction

我们提出了一个简单、灵活、通用的实例分割框架,称为Mask R-CNN。我们的方法能够有效检测图像中的目标,同时为每个实例生成高质量的分割掩码。Mask R-CNN通过添加一个预测对象掩码的分支,与现有的边框识别分支并行,扩展了之前的Faster R-CNN。Mask R-CNN的训练很简单,只为Faster R-CNN增加了一小部分开销,运行速度为5帧/秒。此外,Mask R-CNN很容易泛化到其他任务,如人体姿态估计。我们展示了Mask R-CNN在COCO挑战赛的实例分割、目标检测和人物关键点检测任务上的最优结果。在不使用花哨技巧的情况下,Mask R-CNN在各项任务上都优于现有的单一模型,包括COCO 2016挑战赛的冠军。我们希望Mask R-CNN能够成为一个坚实的基线,并有助于简化未来实例识别的研究。

Fast/Faster R-CNN和Fully Convolutional Network(FCN)框架极大地推动了计算机视觉领域中目标检测和语义分割等方向的发展。这些方法的概念很直观,具有良好的灵活性和鲁棒性,并且能够快速训练和推理。我们这项工作的目标是为实例分割任务开发一个相对可行的框架。

实例分割具有一定的挑战性,因为它需要正确检测图像中的所有对象,同时还要精确分割每个实例。因此,它结合了目标检测和语义分割等计算机视觉任务中的元素。目标检测旨在对单个物体进行分类,并使用边框对每个物体进行定位。语义分割旨在将每个像素归类到一组固定的类别,而不区分对象实例。鉴于此,人们可能会认为需要一套复杂的方法才能获得良好的结果。然而,我们证明了一个令人惊讶的事实:简单、灵活、快速的系统也可以超越现有的最先进的实例分割模型。

我们的方法称为Mask R-CNN,通过在每个RoI(感兴趣区域,Region of Interest)上添加一个预测分割掩码的分支来扩展Faster R-CNN,并与现有的用于分类和边框回归的分支并行。掩码分支是应用于每个RoI的一个小FCN,以像素到像素的方式预测分割掩码,并且只会增加较小的计算开销。Mask R-CNN是基于Faster R-CNN框架而来的,易于实现和训练,有助于广泛、灵活的架构设计。

原则上,Mask R-CNN是Faster R-CNN的直观扩展,但正确构建掩码分支对于获得好的结果至关重要。最重要的是,Faster R-CNN的设计没有考虑网络输入和输出之间的像素到像素的对齐。这一点在RoIPool(处理实例的核心操作)如何执行粗空间量化来提取特征上表现得最为明显。为了修正错位,我们提出了一个简单的、没有量化的层,称为RoIAlign,它忠实地保留了精确的空间位置。尽管这看起来是一个很小的变化,但是RoIAlign有很大的影响:它将掩码精度提高了10%-50%,在更严格的localization指标下显示出更大的收益。其次,我们发现有必要将掩码和类别预测解耦:我们为每个类别独立预测一个二进制掩码,类别之间没有竞争,并依靠网络的RoI分类分支来预测类别。相比之下,FCN通常执行逐像素的多分类操作,将分割和分类耦合在一起,我们的实验结果表明这种方法的实例分割效果不佳。

在不使用花哨技巧的情况下,Mask R-CNN在COCO实例分割任务上就超越了之前的所有SOTA单模型,包括COCO 2016比赛的冠军。作为副产品,我们的方法在COCO目标检测任务上也表现出色。在消融实验中,我们评估了多个基本实例,这使我们能够证明Mask R-CNN的鲁棒性,并分析其核心因素的影响。

我们的模型可以在GPU上以每帧约200ms的速度运行,在一台8-GPU的机器上进行COCO训练需要1-2天。我们相信,快速的训练和测试,以及框架的灵活性和准确性,将有利于未来实例分割的研究。

最后,我们通过COCO关键点数据集上的人体姿态估计任务展示了Mask R-CNN框架的通用性。通过将每个关键点视为一个独热二进制掩码,只需对Mask R-CNN稍加修改,即可用于检测特定实例的姿态。Mask R-CNN超越了COCO 2016关键点检测比赛的冠军,并且能够以5帧/秒的速度运行。因此,Mask R-CNN可以被更广泛地视为一个实例识别的灵活框架,并且很容易泛化到其他更复杂的任务上。

模型方法

Mask R-CNN方法很简单:Faster R-CNN对每个候选对象有两个输出,一个是类别标签,另一个是边框偏移量。在此基础上,我们添加了第三个分支,用于输出分割掩码。因此,Mask R-CNN是一个自然且直观的想法。但是掩码输出不同于类别和边框输出,需要提取更精细的对象空间布局。接下来,我们介绍了Mask R-CNN的关键元素,包括像素到像素对齐,这是Fast/Faster R-CNN所缺失的部分。

用于实例分割的Mask R-CNN框架

RoIAlign:虚线网格表示特征映射图,实线边框表示RoI(Region of Interest),点表示每个边框中的4个采样点。RoIAlign通过双线性插值从特征映射图上的相邻网格点计算每个采样点的值。

  • Network Architecture: 为了表述清晰,有两种分类方法
  1. 使用了不同的backbone:resnet-50,resnet-101,resnext-50,resnext-101;
  2. 使用了不同的head Architecture:Faster RCNN使用resnet50时,从Block 4导出特征供RPN使用,这种叫做ResNet-50-C4
  3. 作者使用除了使用上述这些结构外,还使用了一种更加高效的backbone:FPN(特征金字塔网络)
Head架构:我们扩展了两个现有的Faster R-CNN Head。
  • Mask R-CNN基本结构:与Faster RCNN采用了相同的two-state结构:首先是通过一阶段网络找出RPN,然后对RPN找到的每个RoI进行分类、定位、并找到binary mask。这与当时其他先找到mask然后在进行分类的网络是不同的。
  • Mask R-CNN的损失函数L = L{_{cls}} + L{_{box}} + L{_{mask}} (当然了,你可以在这里调权以实现更好的效果)
  • Mask的表现形式(Mask Representation):因为没有采用全连接层并且使用了RoIAlign,我们最终是在一个小feature map上做分割。
  • RoIAlign:RoIPool的目的是为了从RPN网络确定的ROI中导出较小的特征图(a small feature map,eg 7×7),ROI的大小各不相同,但是RoIPool后都变成了7×7大小。RPN网络会提出若干RoI的坐标以[x,y,w,h]表示,然后输入RoI Pooling,输出7×7大小的特征图供分类和定位使用。问题就出在RoI Pooling的输出大小是7×7上,如果RON网络输出的RoI大小是8*8的,那么无法保证输入像素和输出像素是一一对应,首先他们包含的信息量不同(有的是1对1,有的是1对2),其次他们的坐标无法和输入对应起来。这对分类没什么影响,但是对分割却影响很大。RoIAlign的输出坐标使用插值算法得到,不再是简单的量化;每个grid中的值也不再使用max,同样使用差值算法。

Implementation Details

使用Fast/Faster相同的超参数,同样适用于Mask RCNN

  • Training:

1、与之前相同,当IoU与Ground Truth的IoU大于0.5时才会被认为有效的RoI,L{_{mask}}只把有效RoI计算进去。

2、采用image-centric training,图像短边resize到800,每个GPU的mini-batch设置为2,每个图像生成N个RoI,在使用ResNet-50-C4 作为backbone时,N=64,在使用FPN作为backbone时,N=512。作者服务器中使用了8块GPU,所以总的minibatch是16, 迭代了160k次,初始lr=0.02,在迭代到120k次时,将lr设定到 lr=0.002,另外学习率的weight_decay=0.0001 momentum = 0.9。如果是resnext,初始lr=0.01,每个GPU的mini-batch是1。

3、RPN的anchors有5种scale,3种ratios。为了方便剥离、如果没有特别指出,则RPN网络是单独训练的且不与Mask R-CNN共享权重。但是在本论文中,RPN和Mask R-CNN使用一个backbone,所以他们的权重是共享的。(Ablation Experiments 为了方便研究整个网络中哪个部分其的作用到底有多大,需要把各部分剥离开)

  • Inference:在测试时,使用ResNet-50-C4作为 backbone情况下proposal number=300,使用FPN作为 backbone时proposal number=1000。然后在这些proposal上运行bbox预测,接着进行非极大值抑制。mask分支只应用在得分最高的100个proposal上。顺序和train是不同的,但这样做可以提高速度和精度。mask 分支对于每个roi可以预测k个类别,但是我们只要背景和前景两种,所以只用k-th mask,k是根据分类分支得到的类型。然后把k-th mask resize成roi大小,同时使用阈值分割(threshold=0.5)二值化

Experiments

Main Results

在下图中可以明显看出,FCIS的分割结果中都会出现一条竖着的线(systematic artifacts),这线主要出现在物体重的部分,作者认为这是FCIS架构的问题,无法解决的。但是在Mask RCNN中没有出现。

Ablation Experiments(剥离实验)

  • Architecture:
    从table 2a中看出,Mask RCNN随着增加网络的深度、采用更先进的网络,都可以提高效果。注意:并不是所有的网络都是这样。
  • Multinomial vs. Independent Masks:(mask分支是否进行类别预测)从table 2b中可以看出,使用sigmoid(二分类)和使用softmax(多类别分类)的AP相差很大,证明了分离类别和mask的预测是很有必要的
  • Class-Specific vs. Class-Agnostic Masks:目前使用的mask rcnn都使用class-specific masks,即每个类别都会预测出一个mxm的mask,然后根据类别选取对应的类别的mask。但是使用Class-Agnostic Masks,即分割网络只输出一个mxm的mask,可以取得相似的成绩29.7vs30.3
  • RoIAlign:tabel 2c证明了RoIAlign的性能
  • Mask Branch:tabel 2e,FCN比MLP性能更好

Bounding Box Detection Results    

  • Mask RCNN精度高于Faster RCNN
  • Faster RCNN使用RoI Align的精度更高
  • Mask RCNN的分割任务得分与定位任务得分相近,说明Mask RCNN已经缩小了这部分差距。

Timing 

  • Inference:195ms一张图片,显卡Nvidia Tesla M40。其实还有速度提升的空间,比如减少proposal的数量等。
  • Training:ResNet-50-FPN on COCO trainval35k takes 32 hours  in our synchronized 8-GPU implementation (0.72s per 16-image mini-batch),and 44 hours with ResNet-101-FPN。

Mask R-CNN for Human Pose Estimation

让Mask R-CNN预测k个masks,每个mask对应一个关键点的类型,比如左肩、右肘,可以理解为one-hot形式。

  • 使用cross entropy loss,可以鼓励网络只检测一个关键点;
  • ResNet-FPN结构
  • 训练了90k次,最开始lr=0.02,在迭代60k次时,lr=0.002,80k次时变为0.0002

MICCAI 2022:基于 MLP 的快速医学图像分割网络 UNeXt

论文地址: https://arxiv.org/abs/2203.04967

github:https://github.com/jeya-maria-jose/UNeXt-pytorch

UnetX 网络结构

Datasets

  1. ISIC 2018 – Link
  2. BUSI – Link

MICCAI 2022:基于 MLP 的快速医学图像分割网络 UNeXt

前言

最近 MICCAI 2022 的论文集开放下载了,地址:https://link.springer.com/book/10.1007/978-3-031-16443-9 ,每个部分的内容如下所示:

Part I: Brain development and atlases; DWI and tractography; functional brain networks; neuroimaging; heart and lung imaging; dermatology;

Part II: Computational (integrative) pathology; computational anatomy and physiology; ophthalmology; fetal imaging;

Part III: Breast imaging; colonoscopy; computer aided diagnosis;

Part IV: Microscopic image analysis; positron emission tomography; ultrasound imaging; video data analysis; image segmentation I;

Part V: Image segmentation II; integration of imaging with non-imaging biomarkers;

Part VI: Image registration; image reconstruction;

Part VII: Image-Guided interventions and surgery; outcome and disease prediction; surgical data science; surgical planning and simulation; machine learning – domain adaptation and generalization;

Part VIII: Machine learning – weakly-supervised learning; machine learning – model interpretation; machine learning – uncertainty; machine learning theory and methodologies.

其中关于分割有两个部分,Image segmentation I 在 Part IV, 而 Image segmentation II 在 Part V。

随着医学图像的解决方案变得越来越适用,我们更需要关注使深度网络轻量级、快速且高效的方法。具有高推理速度的轻量级网络可以被部署在手机等设备上,例如 POCUS(point-of-care ultrasound)被用于检测和诊断皮肤状况。这就是 UNeXt 的动机。

方法概述

之前我们解读过基于 Transformer 的 U-Net 变体,近年来一直是领先的医学图像分割方法,但是参数量往往不乐观,计算复杂,推理缓慢。这篇文章提出了基于卷积多层感知器(MLP)改进 U 型架构的方法,可以用于图像分割。设计了一个 tokenized MLP 块有效地标记和投影卷积特征,使用 MLPs 来建模表示。这个结构被应用到 U 型架构的下两层中(这里我们假设纵向一共五层)。文章中提到,为了进一步提高性能,建议在输入到 MLP 的过程中改变输入的通道,以便专注于学习局部依赖关系特征。还有额外的设计就是跳跃连接了,并不是我们主要关注的地方。最终,UNeXt 将参数数量减少了 72 倍,计算复杂度降低了 68 倍,推理速度提高了 10 倍,同时还获得了更好的分割性能,如下图所示。

UNeXt 架构

UNeXt 的设计如下图所示。纵向来看,一共有两个阶段,普通的卷积和 Tokenized MLP 阶段。其中,编码器和解码器分别设计两个 Tokenized MLP 块。每个编码器将分辨率降低两倍,解码器工作相反,还有跳跃连接结构。每个块的通道数(C1-C5)被设计成超参数为了找到不掉点情况下最小参数量的网络,对于使用 UNeXt 架构的实验,遵循 C1 = 32、C2 = 64、C3 = 128、C4 = 160 和 C5 = 256。

TokMLP 设计思路

关于 Convolutional Stage 我们不做过多介绍了,在这一部分重点专注 Tokenized MLP Stage。从上一部分的图中,可以看到 Shifted MLP 这一操作,其实思路类似于 Swin transformer,引入基于窗口的注意力机制,向全局模型中添加更多的局域性。下图的意思是,Tokenized MLP 块有 2 个 MLP,在一个 MLP 中跨越宽度移动特征,在另一个 MLP 中跨越高度移动特征,也就是说,特征在高度和宽度上依次移位。论文中是这么说的:“我们将特征分成 h 个不同的分区,并根据指定的轴线将它们移到 j=5 的位置”。其实就是创建了随机窗口,这个图可以理解为灰色是特征块的位置,白色是移动之后的 padding。

补充:MLP拥有大量参数,计算成本高且容易过度拟合,而且因为层之间的线性变换总是将前一层的输出作为一个整体,所以MLP在捕获输入特征图中的局部特征结构的能力较弱。通过轴向移动特征信息, Shifted MLP可以得到不同方向的信息流,这有助于捕获局部相关性。该操作使得我们采用纯MLP架构即可取得与CNN相同的感受野。

解释过 Shifted MLP 后,我们再看另一部分:tokenized MLP block。首先,需要把特征转换为 tokens(可以理解为 Patch Embedding 的过程,感觉这个就是个普通卷积,而且作者为了保证conv后的矩阵减半,设置步幅为2,总之,有些编故事的意思了)。为了实现 tokenized 化,使用 kernel size 为 3 的卷积(patch_size=3, stride=2),这样会使得矩阵H和W减半,并将通道的数量改为 E,E 是 embadding 嵌入维度( token 的数量),也是一个超参数。然后把这些 token 送到上面提到的第一个跨越宽度的 MLP 中。

这里会产生了一个疑问,关于 kernel size 为 3 的卷积,使用的是什么样的卷积层?答:这里还是普通的卷积,文章中提到了 DWConv(DepthWise Conv),是后面的特征通过 DW-Conv 传递。使用 DWConv 有两个原因:(1)它有助于对 MLP 特征的位置信息进行编码。MLP 块中的卷积层足以编码位置信息,它实际上比标准的位置编码表现得更好。像 ViT 中的位置编码技术,当测试和训练的分辨率不一样时,需要进行插值,往往会导致性能下降。(2)DWConv 使用的参数数量较少。

这时我们得到了 DW-Conv 传递过来的特征,然后使用 GELU 完成激活。接下来,通过另一个 MLP(跨越height)传递特征,该 MLP 把进一步改变了特征尺寸。在这里还使用一个残差连接,将原始 token 添加为残差。然后我们利用 Layer Norm(LN),将输出特征传递到下一个块。LN 比 BN 更可取,因为它是沿着 token 进行规范化,而不是在 Tokenized MLP 块的整个批处理中进行规范化。上面这些就是一个 tokenized MLP block 的设计思路。

此外,文章中给出了 tokenized MLP block 涉及的计算公式:

其中 T 表示 tokens,H 表示高度,W 表示宽度。值得注意的是,所有这些计算都是在 embedding 维度 H 上进行的,它明显小于特征图的维度 HN×HN,其中 N 取决于 block 大小。在下面的实验部分,文章将 H 设置为 768。

实验部分

实验在 ISIC 和 BUSI 数据集上进行,可以看到,在 GLOPs、性能和推理时间都上表现不错。

下面是可视化和消融实验的部分。可视化图可以发现,UNeXt 处理的更加圆滑和接近真实标签。

消融实验可以发现,从原始的 UNet 开始,然后只是减少过滤器的数量,发现性能下降,但参数并没有减少太多。接下来,仅使用 3 层深度架构,既 UNeXt 的 Conv 阶段。显着减少了参数的数量和复杂性,但性能降低了 4%。加入 tokenized MLP block 后,它显着提高了性能,同时将复杂度和参数量是一个最小值。接下来,我们将 DWConv 添加到 positional embedding,性能又提高了。接下来,在 MLP 中添加  Shifted 操作,表明在标记化之前移位特征可以提高性能,但是不会增加任何参数或复杂性。注意:Shifted MLP 不会增加 GLOPs。

一些理解和总结

在这项工作中,提出了一种新的深度网络架构 UNeXt,用于医疗图像分割,专注于参数量的减小。UNeXt 是一种基于卷积和 MLP 的架构,其中有一个初始的 Conv 阶段,然后是深层空间中的 MLP。具体来说,提出了一个带有移位 MLP 的标记化 MLP 块。在多个数据集上验证了 UNeXt,实现了更快的推理、更低的复杂性和更少的参数数量,同时还实现了最先进的性能。

另外,个人觉得 带有移位 MLP 的标记化 MLP 块这里其实有点讲故事的意思了。

我在读这篇论文的时候,直接注意到了它用的数据集。我认为 UNeXt 可能只适用于这种简单的医学图像分割任务,类似的有 Optic Disc and Cup Seg,对于更复杂的,比如血管,软骨,Liver Tumor,kidney Seg 这些,可能效果达不到这么好,因为运算量被极大的减少了,每个 convolutional 阶段只有一个卷积层。MLP 魔改 U-Net 也算是一个尝试,在 Tokenized MLP block 中加入 DWConv 也是很合理的设计。

代码实现:

class shiftmlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.shift_size = shift_size
        self.pad = shift_size // 2

        
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    


    def forward(self, x, H, W):
        # pdb.set_trace()
        B, N, C = x.shape

        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        #pad,方便后面的torch.chunk
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
        #按照dim=1维度,分成 self.shift_size(5)个块
        xs = torch.chunk(xn, self.shift_size, 1)
        #torch.roll(x,y,d)将x,沿着d维度,向上/下roll y个值
        x_shift = [torch.roll(x_c, shift, 2) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
        x_cat = torch.cat(x_shift, 1)
        #x.narrow(*dimension*, *start*, *length*) → Tensor 表示取变量x的第dimension维,从索引start开始到(start+length-1)范围的值。
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)

        x_s = x_s.reshape(B,C,H*W).contiguous()
        x_shift_r = x_s.transpose(1,2)

        x = self.fc1(x_shift_r)

        x = self.dwconv(x, H, W)
        x = self.act(x) 
        x = self.drop(x)

        xn = x.transpose(1, 2).view(B, C, H, W).contiguous()
        xn = F.pad(xn, (self.pad, self.pad, self.pad, self.pad) , "constant", 0)
        xs = torch.chunk(xn, self.shift_size, 1)
        x_shift = [torch.roll(x_c, shift, 3) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
        x_cat = torch.cat(x_shift, 1)
        x_cat = torch.narrow(x_cat, 2, self.pad, H)
        x_s = torch.narrow(x_cat, 3, self.pad, W)
        x_s = x_s.reshape(B,C,H*W).contiguous()
        x_shift_c = x_s.transpose(1,2)

        x = self.fc2(x_shift_c)
        x = self.drop(x)
        return x

class shiftedBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()


        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):

        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x

Vision MLP —Swin-MLP

code:https://github.com/microsoft/Swin-Transformer

Swin MLP 代码来自 Swin Transformer 的官方实现。Swin Transformer 作者们在已有模型的基础上实现了 Swin MLP 模型,证明了 Window-based attention 对于 MLP 模型的有效性。

把张量 (B, H, W, C) 分成 window (B×H/M×W/M, M, M, C),其中M是 window_size。这一步相当于得到 B×H/M×W/M 个大小为 (M, M, C) 的 window。

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

把 window (B×H/M×W/M, M, M, C) 变回张量 (B, H, W, C)。

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

一个 Swin MLP Block

class SwinMLPBlock(nn.Module):
    r""" Swin MLP Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.padding = [self.window_size - self.shift_size, self.shift_size,
                        self.window_size - self.shift_size, self.shift_size]  # P_l,P_r,P_t,P_b

        self.norm1 = norm_layer(dim)
        # use group convolution to implement multi-head MLP
        self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2,
                                     self.num_heads * self.window_size ** 2,
                                     kernel_size=1,
                                     groups=self.num_heads)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # shift
        if self.shift_size > 0:
            P_l, P_r, P_t, P_b = self.padding
            shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0)
        else:
            shifted_x = x
        _, _H, _W, _ = shifted_x.shape

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # Window/Shifted-Window Spatial MLP
        x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads)
        x_windows_heads = x_windows_heads.transpose(1, 2)  # nW*B, nH, window_size*window_size, C//nH
        x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size,
                                                  C // self.num_heads)
        spatial_mlp_windows = self.spatial_mlp(x_windows_heads)  # nW*B, nH*window_size*window_size, C//nH
        spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size,
                                                       C // self.num_heads).transpose(1, 2)
        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C)

        # merge windows
        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W)  # B H' W' C

        # reverse shift
        if self.shift_size > 0:
            P_l, P_r, P_t, P_b = self.padding
            x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

注意 F.pad(x, [0, 0, P_l, P_r, P_t, P_b], “constant”, 0) 的对象是 x,维度是 (B, H, W, C)。
padding相当于是第3维 (C 这一维) 不填充,第2维 (W 这一维) 左右分别填充 P_l, P_r,第1维 (H 这一维) 左右分别填充 P_t, P_b。
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C:
这句代码把 shifted_x 分成 nW*B 个 windows,其中每个 window 的维度是 (window_size, window_size, C)。

# reverse shift
if self.shift_size > 0:
P_l, P_r, P_t, P_b = self.padding
x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()
else:
x = shifted_x
这里是如果进行了 shift 操作,则最后取得结果也应该是没有 padding 的部分,正好是 shifted_x[:, P_t:-P_b, P_l:-P_r, :]。

一个 Swin MLP Block 的 FLOPs,注意 WSA 的计算量是:

FLOPs (WSA) = (window_size * window_size)^2 * dim * number_window

def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W

        # Window/Shifted-Window Spatial MLP
        if self.shift_size > 0:
            nW = (H / self.window_size + 1) * (W / self.window_size + 1)
        else:
            nW = H * W / self.window_size / self.window_size
        flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

每个 stage 之间的 PatchMerging连接,把 resolution 变为一半,dim 变为2倍。

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def flops(self):
        H, W = self.input_resolution
        # norm
        flops = H * W * self.dim
        # reduction
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
  • Patch Merging 操作把相邻的 2×2 个 tokens 给合并到一起,得到的 token 的维度是4C。
    Patch Merging 操作再通过一次线性变换把维度降为2C。

一个 Swin MLP Layer

class BasicLayer(nn.Module):
    """ A basic Swin MLP layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., drop=0., drop_path=0.,
                 norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinMLPBlock(dim=dim, input_resolution=input_resolution,
                         num_heads=num_heads, window_size=window_size,
                         shift_size=0 if (i % 2 == 0) else window_size // 2,
                         mlp_ratio=mlp_ratio,
                         drop=drop,
                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                         norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
  • 包含 depth 个 Swin MLP Block。
    注意计算 FLOPs 的方式:每个 blk 和 downsample 都自带 flops() 方法,可以直接来调用。

PatchEmbedded 操作

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
  • 和 ViT 的 Patch Embedded 操作一样,本质上是一个 K=patch size,s=patch size 的 nn.Conv2d 操作,注意卷积 FLOPs 的计算公式即可。

SwinMLP 整体模型架构

class SwinMLP(nn.Module):
    r""" Swin MLP

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin MLP layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        drop_rate (float): Dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               drop=drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        # adaptive average pool
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        # head
        flops += self.num_features * self.num_classes
        return flops
  • 由4个 Stage 组成,每个 Stage 由 BasicLayer 实现。
    传入的 depths 代表每个 Stage 的层数,比如 Swin-T 就是:[2, 2, 6, 2]。