DETR :End to End Object Detection with Transformers

目标检测领域的里程碑式的工作

https://arxiv.org/abs/2005.12872

code:https://github.com/facebookresearch/detr Facebook AI(meta AI)

DETRDetection Transformers

文章题目:简单明了,包含两个关键词:端到端、transformer

目标检测领域:从目标检测开始火到detr都很少有端到端的方法,大部分方法最后至少需要后处理操作(NMS,non-maximum suppression非极大值抑制)。无论是proposal based方法、anchor based方法、non-anchor based方法,最后都会生成很多预测框,如何去除这些冗余的框就是NMS要做的事情。

问题:有了NMS,模型调参就会很复杂,而且即使训练好了一个模型,部署起来也非常困难(NMS不是所有硬件都支持)。所以一个简单的、端到端模型一直是大家梦寐以求的,而detr的出现解决了这些痛点。

一、Detr目标:

1、不需要proposal、不需要anchor,直接利用transformer这种全局建模的能力,把目标检测看做是集合预测问题

2、因为有了这种全局建模的能力,detr不会有那么多冗余框,最后出什么结果就是什么结果,不需要NMS做后处理,让模型的训练和部署简单不少

目的:不想让大家觉得目标检测是比图像分类难很多的任务,都可以用简单的,优雅的框架做出来

二、摘要

作者说,他们就是把目标检测的任务看成是一个集合预测问题:目标检测本来任务就是给定一个图像,预测一堆框,每个框不仅要知道的其坐标,还要知道框里包含物体的类别,这些框就是一个集合,不同的图像对应的集合也是不同的,给定一个图片,我要预测这个集合

因此这篇文章就是把目标检测做成一个端到端的框架,把之前特别依赖人的先验知识的部分删掉了(NMS部分、anchor),一旦把这两个部分拿掉之后,我们也不用费尽心思设计这种anchor,最后不会出现这么多框,不会用到NMS,也不会用到很多超参去调。两个贡献:1、使用新的目标函数,通过二分图匹配的方式,强制模型输出一组独一无二的预测(没有那么多冗余框,每个物体理想状态下就会生成一个框)。2、另外使用encoder-decoder的架构。

两个小贡献:

1、decoder还有另外一个输入learned object query,类似anchor的意思(给定这些object query之后,detr就可以把learned object query和全局图像信息结合一起,通过不同的做注意力操作,从而让模型直接输出最后的一组预测框)

2、想法&&实效性:并行比串行更合适,并不是检测一个大物体前必须先检测一个小物体,或从左到右检测,我们希望越快越好

DETR的好处:

1、简单性:想法上简单,不需要一个特殊的library,只要硬件支持transformer或CNN,就一定支持detr

2、性能:在coco数据集上,detr和一个训练非常好的faster RCNN基线网络取得了差不多的效果,模型内存和速度也和faster RCNN差不多

3、想法好,解决了目标检测领域很多痛点,写作好

4、别的任务:全景分割任务上detr效果很好,detr能够非常简单拓展到其他任务上

三、引言

1、目标检测任务:对每一个感兴趣的物体,去预测一些框,和物体类别,就是一个集合预测问题。

2、现在大多数好用的目标检测器,都是用间接的方式去处理集合预测问题,(1)比如proposal方式(如RCNN系列工作),(2)anchor方式(YOLO系列,focal loss),non-anchor based方法(物体中心点center net,FCOS),他们都没有直接做集合预测任务,而是设计一个替代(回归、分类)解决目标检测问题。所有这些方法性能受限于后处理操作(NMS),由于用了anchor和NMS导致检测器都非常复杂,难以优化和调参。

3、端到端的思想已经在别的很多任务里大范围使用,而且使任务更加简单好用,我们不要先验知识,就是要用一个端到端网络。

detr流程(训练):

1、CNN提特征

2、特征拉直,送到encoder-decoder中,encoder作用:进一步学习全局信息,为近下来的decoder,也就是最后出预测框做铺垫。直观的解释为什么需要使用transformer encoder呢?如果使用了transformer encoder,那么每一个点或者说每一个特征就会跟着图片里面的其他的特征有交互了,这样大概就知道那块是那个物体,对于同一个物体就应该只出一个框而不是好多框,所以全局的建模有利于移除冗余的框。

3、decoder生成框的输出,当你有了图像特征之后,还会有一个object query(限定了你要出多少框),通过query和特征在decoder里进行自注意力操作,得到输出的框(文中是100,无论是什么图片都会预测100个框)

4、生成的100个框如何与ground truth这个框做匹配并计算 loss? :二分图匹配,如上图,我们计算100个预测的框和2个GT框的matching loss,决定100个预测框哪两个是独一无二对应到红黄色的GT框,用匹配的框去算目标检测的loss。而没有匹配到的98个框就会被标记为没有物体。

5、推理1、2、3一致,第四步loss不需要,直接在最后的输出上用一个阈值卡一个输出的置信度,置信度比较大(>0.7的)保留,置信度小于0.7的当做背景物体。

结果:

1、detr对大物体预测很准,归功于transformer,能进行全局建模(原来使用anchor的话就会受限于anchor大小)

2、缺陷:对小物体效果不好(多尺度、多特征,可以提高小物体的检测)后续改进:Deformable DETR

3、detr训练很慢,500个epoch(coco大多数模型一般训练几十个epoch就行)

检测效果:

detr由于使用transformer全局建模,没有用anchor,想检测多大物体就检测多大,所以检测大物体效果较好。detr框架太简单,没有多尺度特征,没有FPN,没有复杂的目标检测头,所以在小目标检测效果不好

四、相关工作

目标检测:

目前大多数的检测器是根据初始猜测做预测:

1、two-stage:初始猜测是中间的proposal

2、one-stage:初始猜测是anchor或物体中心点

最近一篇论文做了详细比较,发现他们的性能和刚开始的初始猜测非常相关,怎么做后处理对性能影响至关重要

怎么后处理:

1、集合思想:可学习的NMS方法、关系型网络,可以利用自注意力方法去处理物体之间的联系,得出独一无二的预测,就不需要后处理的步骤(性能较低)

解决:人工干预:手工设计的场景特征帮助模型学习,但是detr目标是想让目标检测任务更加简单,不希望用到过多人工先验知识

2、循环检测器:encoder-decoder:让detr工作主要原因:transformer

五、方法

分两块:1、基于集合的目标函数怎么做,作者如何通过二分图匹配把预测的框和GT框连接在一起,算得目标函数 2、detr具体模型架构

目标函数部分:

detr模型最后输出是一个固定集合,无论图片是什么,最后都会输出n个(本文n=100)预测框

问题:detr每次都会出100个输出,但是实际上一个图片的GT的bounding box可能只有几个,如何匹配?如何计算loss?怎么知道哪个预测框对应GT框?

作者这里把这个问题转换成了一个二分图匹配的问题:

二分图又称作二部图,是图论中的一种特殊模型。 设G=(V,E)是一个无向图,如果顶点V可分割为两个互不相交的子集(A,B),并且图中的每条边(i,j)所关联的两个顶点i和j分别属于这两个不同的顶点集(i in A,j in B),则称图G为一个二分图。简而言之,就是顶点集V可分割为两个互不相交的子集,并且图中每条边依附的两个顶点都分属于这两个互不相交的子集,两个子集内的顶点不相邻。

加权二分图匹配可以认为是有ABC三个工人,以及xyz三个工作,每个工人去做xyz工作的花费不同,如何去为每一个个人安排一个工作,使得最后我们的花费最低,可以使用遍历的方法,亦可以有很多高效的方法:匈牙利算法。

另外scipy包提供的linear sum assignment可以完成这个最优排列。detr论文里:代码也用的linear sum assignment函数来计算对应的匹配关系,只需要提供一个cost matrix矩阵就可以。a,b,c看成100个预测框,x,y,z看成GT框, cost matrix 损失矩阵未必都是正方形,最后丢到这个函数里面得到一个最优匹配。

那么对于目标检测任务,cost matrix 损失矩阵的值应该放些什么?loss包含两部分:分类loss、出框的准确度。所以也就是遍历所有的预测的框,那这些预测的框和gt框去算两个loss,然后把这个loss放到cost matrix矩阵 就可以了。这样就得到了对应gt的预测框(一对一),进而计算loss,梯度回传更新模型参数。

detr主体网络框架:

输入图片大小:3*800*1066(3:rgb),首先使用卷积网络获得特征:2048*25*34,然后降维变成256*25*34,然后给transformer添加位置信息:大小也是256*25*34,特征+位置作为transformer输入,特征拉直: 256*25*34 ==》 850*256,850就是序列长度,256是向量维度。后面的transformer encoder就跟普通的transformer encoder一样,输出==输出,仍然是850*256,接下来送入decoder里面。不同于一般的decoder,这里的object queries是一个可学习的,100*256大小的向量。在decoder里面做cross attension。输入 object queries ,另外一个输入是来自encoder的全局特征850*256.这两个去做自注意力操作,得到一个100*256的特征的decoder输出。最后添加一个检测头全连接层(FFN),获得类别预测(91类)和框预测(4:框的中心的+高度宽度),获得了100个框,利用匈牙利算法跟gt匹配,然后求loss,更新模型。

六、实验

检测效果:detr由于使用transformer全局建模,没有用anchor,想检测多大物体就检测多大,所以检测大物体效果较好。detr框架太简单,没有多尺度特征,没有FPN,没有复杂的目标检测头,所以在小目标检测效果不好。

下面的表格给出了 DETR 与基线 Faster RCNN 的定量性能对比。最上面一部分的 Faster RCNN 的性能结果是 Detectron2 的实现,之所以将 Faster RCNN 分成两部分,是因为 DETR 中使用了近年来很多新的训练 trick,如 GIoU loss、更强的数据增强策略、更长的训练时间,因此作者团队添加这些策略重新训练了 Faster RCNN,以作公平的对比。

近年来的新的训练策略对于目标检测模型的提升非常明显。对比表格的第一、第二部分,完全相同的模型,只是用了更优的训练策略,基本能稳定涨两个点。在同样的训练策略、网络规模大小的情况下,DETR 比 Faster RCNN 高 1-2 个点。对比表格的后两部分可以观察到这一点,DETR 对比基线的 Faster RCNN 还是还是有提升的。

DETR 在大物体的检测上远超 Faster RCNN,但是在小物体的检测上却也低了不少。

表格的后三列分别是小、中、大物体的检测性能,可以观察到 DETR 在大物体的检测上更出色,但是对于小物体的检测甚至远不如 Faster RCNN。大物体检测性能的提升得益于 Transformer 结构的全局建模能力,且没有预置的固定 anchor 的限制,因此预测框想多大就多大。而 DETR 在小物体上表现不佳,是因为本文中 DETR 的模型还是一个比较简单的模型,没有做很多针对目标检测的优化设计,比如针对小物体、多尺度的 FPN 设计。DETR 的网络结构还有待后续工作来改进。

表 1 detr和faster RCNN的对比,+表示用更好的训练策略把三个模型重新训练一遍

gflops参数:每秒进行的浮点运算次数,flops越小,模型越小,跑起来越快?X。如果更关心速度,比较fps

首先我们来看对于 Encoder 的可视化,下图展示了对于一组参考点的 Encoder 注意力热力图的可视化,即参考点对于图像中所有其他点自注意力值的大小。可以观察到,Transformer Encoder 基本已经能够非常清晰地区分开各个物体了,甚至热力图已经有一点实例分割的 mask 图的意思了。在有一定遮挡的情况下(左侧两头牛),也能够清楚地分开哪个是哪个。这种效果正是 Transformer Encoder 的全局建模能力所带来的,每个位置能够感知到图像中所有的其他位置。因此能够区分出图像中的不同物体,从而对于一个物体,尽量只出一个预测框。

通过前面的可视化,我们已经看到,Encoder 学习了一个全局的特征,基本已经能够区分开图中不同的物体。但是对于目标检测来说,大致地区分开不同的物体是不够的,我们还需要精确的物体的边界框坐标,这部分就由 Decoder 来做。

下图在 Decoder 特征中对每个不同的物体做了注意力的可视化,比如左图中的两头大象分别由蓝色和橙色表示。可以观察到,Decoder 网络中对于每个物体的注意力都集中在物体的边界位置,如大象的鼻子、尾巴、象腿等处。作者认为这是 Decoder 在区分不同物体边界的极值点(extremities),在 Encoder 能够区分开不同的物体之后,Decoder 再来关注不同物体边界的具体位置,最终精准地预测出不同物体的边框位置。因此,Encoder-Decoder 的结构是必要的,它们各司其职,一个都不能少。

扩展到全景分割任务

作者同时还将该网络应用于全景分割任务中.增加一个分割的head就可以。

下图是20个object query可视化(n=100,这里只有20个)

object query 到底学了什么(绿色代表小的bounding box,红色代表大的横向bounding box,蓝色代表大的竖向bounding box)object query和anchor有些像,anchor是提前定一些bounding box,把预测和这些提前定好的bounding box对比,object query是可以学习的。以第一个 object query 来说:对于一个图片, object query 会去问图片的左下角有没有小物体,以及中间有没有横向的大物体。

为了说明端到端的 DETR 框架的简洁性,作者在论文末尾给出了 DETR 模型定义、推理的 “伪代码”,总共不到 50 行。之所以这里的伪代码要加引号,是因为其实这已经不算是伪代码了,而是直接可运行的 PyTorch 代码。当然这个版本缺少了一些细节,但也完全能够展现出 DETR 的流程了。该版本直接用来训练,最终也能达到 40 的 AP。读者可以对应伪代码再过一遍刚才介绍的 DETR 完成流程,体会一下一个端到端的目标检测框架有多幺简洁。

代码实现:

后续关于detr的改进工作:

1、Deformable DETR: Deformable Transformers for End-to-End Object Detection

2、Omni-DETR: Omni-Supervised Object Detection with Transformers

3、UP-DETR: Unsupervised Pre-training for Object Detection with Transformers

4、PnP-DETR: Towards Efficient Visual Analysis with Transformers

5、SMAC-DETR

6、DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR

7、Accelerating DETR Convergence via Semantic-Aligned Matching

8、DN-DETR: Accelerate DETR Training by Introducing Query DeNoising

9、Open-Vocabulary DETR with Conditional Matching

10、OW-DETR: Open-world Detection Transformer

多模态| ALBEF

视频:多模态论文串讲

https://arxiv.org/abs/2107.07651

code: https://github.com/salesforce/ALBEF

写在前面:最近看了很多多模态的工作,现有的设计有哪些不足?我们又该如何去改进呢?首先来看模型的结构,因为需要处理文本和图片,所以模型开始需要有两个分支,分别抽取图像和文本特征。但是在多模态领域,视觉特征的重要性远远大于文本特征,所以要使用更强大的vision Embed,比如vit,同时对于多模态任务,多模态之间的融合也是十分重要的,也要保证模态融合的模型也要尽可能的大,因此网络应该跟(c)相近。模型确定了,接下来如何去训练呢?我们知道CLIP模型使用了一个对比学习的loss:ITCloss,这个效果很好,所以可以使用。另外常见的两个loss: image text matching(ITM),另一个是masked language modeling(MLM) 也可以继续使用。再回到ALBEF论文,其实它就是按照上述思路进行的设计。

它由一个图像编码器、一个文本编码器和一个多模态编码器组成。文章提出了一种图像文本对比损失,在图像文本融合之前对图像文本进行统一表示建模。图像文本匹配损失和掩码语言建模损失被应用于学习图像和文本之间的多模态交互。为了改进噪声数据的学习,我们使用动量模型生成伪目标来作为训练期间的额外监督。

为了改进在噪声监督下的学习,作者提出了动量蒸馏(MoD) ,使模型能够利用一个更大的web数据集。在训练过程中,作者通过取模型参数的移动平均来保持模型的动量版本,并使用动量模型生成伪目标作为额外的监督。 动量蒸馏可以解释为一种在线自我蒸馏的形式,其中使用学生模型组成的集合作为老师。类似的方法已经在半监督学习、标签噪声学习以及最近的对比学习中进行了探索应用。与现有研究不同,本文从理论上和实验上表明,动量蒸馏是一种通用的学习算法,可以提高模型在许多V+L 任务上的性能。

Pre-training Objectives

作者对ALBEF进行了三个目标的预训练:单模态编码器上的图像-文本对比学习(ITC) 、掩蔽语言建模(MLM) 和多模态编码器上的图像-文本匹配(ITM) 。作者通过在线对比 hard negative挖掘来改进ITM。

Image-Text Contrastive Learning

图像-文本对比学习的目的是在融合预训练更好的单模态表示。它学习了一个相似性函数,使匹配的图像-文本对具有更高的相似性得分。和是将[CLS]嵌入映射到标准化的低维(256d)表示的线性变换。

受MoCo的启发,作者维护了两个队列来存储动量单模态编码器的最新的M个图像-文本表示。动量编码器的归一化特征记为和。作者定义了和。

对于每个图像和文本,作者计算softmax归一化的图像到文本和文本到图像的相似度如下:

其中,τ是一个可学习的温度参数。设和表示ground truth的one-hot形式相似性,其中负对的概率为0,正对的概率为1。图像文本对比损失定义为p和y之间的交叉熵H:

Masked Language Modeling

Masked Language Modeling 同时利用图像和上下文文本来预测mask词。作者以15%的概率随机mask输入token,并用特殊token [MASK]替换它们。设表示mask文本,表示模型对mask token的预测概率。MLM使交叉熵损失最小化:

其中是一个one-hot形式的词汇分布,ground truth token的概率为1。

Image-Text Matching

图像-文本匹配可以预测一对图像和文本是正的(匹配)还是负的(不匹配)。作者使用多模态编码器的输出嵌入的[CLS] token作为图像-文本对的联合表示,并附加一个全连接(FC)层,然后是softmax来预测一个两类概率。

其中,是一个表示ground truth标签的二维one-hot向量。

作者提出了一种基于零计算开销的ITM任务进行 hard negatives采样的策略。如果负的图像-文本对共享相似的语义,但细粒度细节不同,那么它们是很难的。作者利用对比相似性来寻找batch内的 hard negatives。

对于一个batch中的每一幅图像,作者按照对比相似性分布从同一batch中抽取一个负文本,其中与图像更相似的文本有更高的机会被采样。同样地,作者还为每个文本采样一个hard negative图像。

Momentum Distillation

用于预训练的图像-文本对大多是从网络中收集起来的,而且它们往往会有噪声。正样本对通常是弱相关的:文本可能包含与图像无关的单词,或者图像可能包含文本中没有描述的实体 。

对于ITC学习,图像的负样本文本也可能与图像的内容相匹配。对于MLM,可能存在其他与描述图像相同(或更好)的标注不同的词。然而,ITC和MLM的one-hot标签会惩罚所有负标签预测,不管它们的正确性如何。为了解决这个问题,作者提出从动量模型生成的伪目标中学习。动量模型是一个连续发展的教师模型,它由单模态和多模态编码器的指数移动平均版本组成。

在训练过程中,训练基础模型,使其预测与动量模型的预测相匹配。具体来说,对于ITC,作者首先使用动量单模态编码器的特征计算图像-文本相似性,这个可以认为是一个softmax score,不再是一个 one hot向量。这样在模型训练的时候,我们希望在训练原始model的时候,不只是让预测跟目标值one hot尽可能接近,也希望能够和动量模型的输出保持一致,这样就能达到一个比较好的折中点,很多信息从one hot label来学习,但是当one hot label是错误的或者是有噪声的时候,我们希望这个稳定的动量模型提供一些改进。

多模态预训练 | ViLT

paper: https://arxiv.org/abs/2102.03334 ICML 2021

code: https://github.com/dandelin/ViLT

图1 Visual comparison of conventional VLP architectures
and our proposed ViLT.

视觉文本多模态任务,极其简单的多模态结构。模态的特征抽取做到了极小化,主要的计算量放在后边的模态融合上,提高了推理速度。多模态领域里程碑式工作。将区域特征,region 从多模态框架中移除。

Vision and Language Pre-training(VLP)已经已经在视觉语言的多模态下游任务中发展的很好。然而,当前VLP的工作主要集中在图像特征抽取上,一般来讲,图像特征抽取的越好,下游任务中的表现就越好。但是,现在主要有两个问题,一是效率太低,速度太慢,抽取图像特征花费大量时间,比多模态融合都多。我们应该花费更多时间在融合上。第二个是,你用一个预训练好的模型去抽取特征,表达能力受限。目标检测数据集不够大,规模不够大。如果模型不是端到端学习,只是从预训练模型抽取特征,大概率来说不是最优解。

Motivation

目前参数量最小的多模态Transformer方法。ViLT使用预训练的ViT来初始化交互的transformer,这样就可以直接利用交互层来处理视觉特征,不需要额外增加一个视觉encoder(如Faster-RCNN)。

Contribution

  1. 第一个基于patch projection的多模态预训练模型,其是首个使用patch projection来做visual embedding的方法。
  2. 证明了可以将BERT的方法和Vison Transformer结合起来用于多模态transformer
  3. 体现了全词掩码在预训练时以及图像增强在微调时的重要性。

Method

现有的视觉语言模型的三种结构类别:

VE = Vision Embedding

TE = Text Embedding

MI = Modality Interaction

上图是4种不同类型的VLP模型示意图。其中每个矩形的高表示相对计算量大小,VE、TE和MI分别是visual embedding、text embedding和modality interaction的简写。

作者提出这4种类型的主要依据有两点:

1.在参数或者计算上,两种模态是否保持平衡。

2.在网络深层中,两种模态是否相互作用。

VSE、VSE++和SCAN属于(a)类型。对图像和文本独立使用encoder,图像的更重,文本的更轻,使用简单的点积或者浅层attention层来表示两种模态特征的相似性。

CLIP属于(b)类型。每个模态单独使用重的transformer encoder,使用池化后的图像特征点积计算特征相似性。

ViLBERT、UNTER和Pixel-BERT属于(c)类型。这些方法使用深层transformer进行交互作用,但是由于VE仍然使用重的卷积网络进行特征抽取,导致计算量依然很大。

作者提出的ViLT属于(d)类型。ViLT是首个将VE设计的如TE一样轻量的方法,该方法的主要计算量都集中在模态交互上。

Modality Interaction Schema

模态交互部分可以分成两种方式:一种是single-stream(如BERT和UNITER),另一种是dual-stream(如ViLBERT和LXMERT)。其中single-stream是对图像和文本concate然后进行交互操作,而dual-stream是不对图像和文本concate然后进行交互操作。ViLT延用single-stream的交互方式,因为dual-stream会引入额外的计算量。

现有的VLP模型的text embedding基本上都使用类BERT结构(图1),但是visual embedding存在着差异。在大多数情况下,visual embedding是现有VLP模型的瓶颈。visual embedding的方法总共有三大类,其中region feature方法通常采用Faster R-CNN二阶段检测器提取region的特征,grid feature方法直接使用CNN提取grid的特征,patch projection方法将输入图片切片投影提取特征。ViLT是首个使用patch projection来做visual embedding的方法。

网络结构ViLT

作者提出的ViLT可以认为是目前最简单的多模态Transformer方法。ViLT使用预训练的ViT来初始化交互的transformer,这样就可以直接利用交互层来处理视觉特征,不需要额外增加一个视觉encoder。

文本特征输入部分,将文本看成一个词序列,通过word embedding matrix转化成word embedding,然后和position embedding进行相加,最后和modal-type embedding进行concate。

图像特征输入部分,将图像切块看成一个图像块序列,通过linear projection转化成visual embedding,然后和postion embedding进行相加,最后和modal-type embedding进行concate。

其中word embedding和visual embedding通过可学习的modal-type embedding标志位来区分,其中0标志位表示word embedding部分,1标志位表示visual embedding部分。

wrod embedding和visual embedding分别都嵌入了一个额外的可学习[class] embedding,方便和下游任务对接。

Pretraining Objectives

ViLT预训练的优化目标有两个:一个是image text matching(ITM),另一个是masked language modeling(MLM)

ImageText Matching:随机以0.5的概率将文本对应的图片替换成不同的图片,然后对文本标志位对应输出使用一个线性的ITM head将输出feature映射成一个二值logits,用来判断图像文本是否匹配。另外ViLT还设计了一个word patch alignment (WPA)来计算teextual subset和visual subset的对齐分数。

Masked Language Modeling:MLM的目标是通过文本的上下文信息去预测masked的文本tokens。随机以0.15的概率mask掉tokens,然后文本输出接两层MLP与车mask掉的tokens。

Whole Word Masking:另外ViLT还使用了whole word masking技巧。whole word masking是将连续的子词tokens进行mask的技巧,避免了只通过单词上下文进行预测。比如将“giraffe”词tokenized成3个部分[“gi”, “##raf”, “##fe”],可以mask成[“gi”, “[MASK]”, “##fe”],模型会通过mask的上下文信息[“gi”,“##fe”]来预测mask的“##raf”,就会导致不利用图像信息。

Experiment

本文提出的方法在效率上大大提升且表现出相似的性能,相比于region feature的方法速度快了60倍,相比于grid feature的方法快了4倍,而且下游任务表现出相似甚至更好的性能。

如图所示,ViLT相比于region feature的方法速度快了60倍,相比于grid feature的方法快了4倍,而且下游任务表现出相似甚至更好的性能。

缺点:

1、性能不够高,在一些数据集上得表现比不过C类方法,有可能因为对于现有的任务来说,因为数据集的bias,或者这个任务需要更多的视觉信息,因此需要更多得视觉部分,最后的效果才能好。

2、虽然推理时间快,但是训练速度很慢。只是结构上简化了多模态学习,但一般人还是玩不起。

Visual Prompt–视觉模板

Visual Prompt Tuning

https://github.com/KMnP/vpt

Exploring Visual Prompts for Adapting Large-Scale Models

https://github.com/hjbahng/visual_prompting

Visual Prompt Tuning

最近NLP领域提出了Prompt范新式,企图革新原先的Fine-tuning方法,而在CV领域中,Prompt其实可以理解为图像label的设计,从这个角度看,Prompt(预测文本中mask的字符,类似完形填空)其实是介于Image caption(迭代预测出每一个字符)和one-hot label(one-hot可以认为是prompt的特例,单字符通过text encoder成one-hot)之间的任务。最近在Visual-Language Model(缩写VLM)任务中,prompt开始展现出强大的能力.

Fine-tuning中:是预训练语言模型“迁就“各种下游任务。具体体现就是上面提到的通过引入各种辅助任务loss,将其添加到预训练模型中,然后继续pre-training,以便让其更加适配下游任务。总之,这个过程中,预训练语言模型做出了更多的牺牲。

Prompting中,是各种下游任务“迁就“预训练语言模型。具体体现也是上面介绍的,我们需要对不同任务进行重构,使得它达到适配预训练语言模型的效果。总之,这个过程中,是下游任务做出了更多的牺牲。

Abstract

目前调整预训练模型的方法是full fine-tuning,即完全微调。本文介绍Visual Prompt Tuning(VPT)作为一种有效的用于大规模Transformer的视觉微调。它只需要在输入空间引入少量(不到1%的模型参数)的可训练参数,同时冻结backbone。会发现在很多情况下,优于完全微调。

Introduction

对于大规模模型适应下游任务时,通常的策略是进行端到端的全面微调,然而这种策略需要为每个人物存储部署单独的主干参数,代价比较高,毕竟现在的Transformer体系结构比较大。

一种简单的方法是使用已经完善的其他策略,如下图(a):仅微调参数的子集,如分类器头部或者偏差项。之前的研究还会试着向主干添加额外的残差结构或者adapter,可以对Transformer实施类似的策略。然而,这些策略会在准确度上执行完全微调。

作者试图探索一种不同的方法:并不通过改变或者微调预训练好的Transformer本身,而是修改其输入。如下图(b)所示:将少量特定任务的可学习参数引入输入空间,同时在下游训练期间冻结backbone。实践中,这些附加参数只是预先加入到Transformer每层输入序列中,并在微调时和线性头一起学习。

在预训练主干用ViT的24个跨域的下游任务中,VPT优于了其他迁移学习的baseline,有20个超过了完全微调,同时保持了为每个单独任务储存较少参数的优势。(NLP任务中,prompt tuning旨在某些情况下才匹配完全微调的性能)。如下图(c)所示,VPT在地数据区尤其有效,结果也进一步表明,VPT是适应不断增长的视觉主干的最有效方法之一。

Visual-Prompt Tuning (VPT) vs . other transfer learning methods. (a) Current
transfer learning protocols are grouped based on the tuning scope: Full fine-tuning,
Head-oriented, and Backbone-oriented approaches. (b) VPT instead adds extra pa-
rameters in the input space. (c) Performance of different methods on a wide range
of downstream classification tasks adapting a pre-trained ViT-B backbone, with mean
and standard deviation annotated. VPT outperforms Full fine-tuning 20 out of 24 cases
while using less than 1% of all model parameters

Related Work

迁移学习两种代表性方法:Adapter在每个Transformer层后插入一个额外的小模块。通常包含一个线性向下头像、线性向上投影及一个残差连接。BitFit是在微调网络时更新偏置项并冻结其余backbone参数。这些方法在NLP已经成熟运用。作者则进一步实验表明了VPT在视觉任务的Transformer的模型调整上性能更加良好。

Prompt最初指的是在输入文本中预编语言汁了,以便预训练好的LM(Language Model)能够“理解”任务。最近的工作则是将prompt视为任务特定的连续向量,并在微调过程中通过梯度直接优化,即Prompt Tuning。prompting依然局限于文本编码器的输入。作者是第一个解决(同样的方法能成功的应用到视觉主干)并研究视觉prompt的普遍性和可行性的工作。

Approach

整体框架图如下图所示:

Visual-Prompt Tuning:

给定一个预先训练好的Transformer,在Embed层后的输入空间引入一组d维的p连续embedding。在微调过程中,只有prompt会被更新,主干将会冻结,根据加入prompt的层数量分为浅VPT和深VPT。

浅VPT :

Prompt仅插入第一层。每一个prompt token都是一个可学习的d维参数。集合和浅VPT表示如下:

VPT-Deep: Prompt被插入每一层的输入控件。集合和深VPT表示如下:

VPT对于多个下游任务都是有帮助的,只需要为每个任务存储学习到的prompt和分类头,重新使用预训练的Transformer,从而显着降低存储成本。

Experiments:

上图是关于prompt的位置,本文提出的是prepend,与直接在embedding上添加对比效果更好。除此之外,作为前置像素或者concat通道的效果也都在下降。

下图则是对prompt长度、深度的消融实验:

最佳提示长度因任务而异,即使只有一个prompt,深VPT的效果仍显着优于另外两种方法。

下图为最终输出的消融实验:

补充:

CoOp

CoOp明显是受到了AutoPrompt的启发,并且CoOp发现CLIP实际上就是prompt在visual-language model中的一个应用,于是CoOp在CLIP的基础上进一步进行改进。

CoOp先在四个数据集上做实验,发现更合理的prompt能够大幅度的提升分类精度尤其是使用了本文提出的CoOp之后,最终的分类精度远超CLIP人为设计的prompt。

和CLIP的主要不同之处在于,CoOp在CLIP的第二个阶段中引入了context optimization。具体的,CoOp将prompt设计为:

t=[V]1[ V]2…[V]M[CLASS]

其中每个[V]M向量跟word embedding的维度相同,可以理解为可学习的context,并且所有类别对应的context共享参数。

Exploring Visual Prompts for Adapting Large-Scale Models

这篇文章参考了Ian Goodfellow等人(2018)对抗样本中的对抗重编程思想, 对抗重编程的目标是我使用一个任务A的网络(无需重新训练该网络)来做任务B。我们设一个经过训练的模型,其原本的任务是,给定输入x,会得到输出f(x)现在有一个攻击者,其对抗任务是对于给定的输入x~,会得到输出g(x~),这里x~和x不一定需要是同域的。这看起来是不可行的,但是攻击者通过学习对抗重编程函数hf(.;θ)和hg(.;θ)来实现这两个任务之间的映射。hf(.;θ)用于将输入从x~所在的域转换到x所在的域,也就是说,经过hf的处理后得到的hf(x~;θ)对于f而言是有效的输入;而hg则将f的输出f(h(x~;θ))映射会g(x~)的输出。

图 1:对抗重编程图示。(a)将 ImageNet 标签映射至对抗任务标签(图像中的方块)。(b)来自对抗任务的图像(左)被嵌入对抗程序的中心(中),得到对抗图像(右)。该对抗程序使 Inception V3 网络执行计算图像中方块数量的任务。(c)使用对抗图像进行推断的图示。把对抗图像输入该网络时,网络预测映射至对抗任务的 ImageNet 标签。

图2提供了适应预训练模型的不同方法的摘要。微调和线性探测的用法非常灵活:它们可用于使模型适应新的输入域或具有不同输出语义的新任务。但是,在线性探针的情况下,在微调和模型输出(通常在倒数第二层的激活)的情况下,他们还需要一定程度的访问模型:参数。域的适应性是模型适应的有趣替代方法,因为它仅使用图像到图像翻译等技术来修改模型的输入[50,19]。像域的适应性一样,视觉提示也将输入修改为模型。因此,一旦最终用户找到了视觉提示,就不需要在测试时间控制模型本身。这打开了独特的应用程序;例如,用户可以将适应域的图像馈送到只能通过输入来操纵的在线API。域的适应性重点是调整源域以看起来像目标域,同时需要源和目标数据集。另一方面,我们认为视觉提示可以以更任意的方式引导模型。例如,可以通过扰动输入像素来使用新的输出语义进行一个完全不同的分类任务来执行完全不同的分类任务。同样,尽管域适应方法通常是输入条件,但我们在本文中探索的视觉提示是固定的(即输入 – agnostic),如NLP,例如在NLP中,将相同的自然语言提示添加到所有模型查询中

CLIP: Contrastive Language-Image Pre-Training

CLIP论文讲解

背景

论文来自 Open AI 2021 年提出的一个成果,相关可参考信息: github、 paper主页 。 之前其实并不太了解多模态预训练领域的成果,最近看到了这篇质量很高的成果。

Hugging Facehttps://huggingface.co/openai/clip-vit-base-patch32(预训练模型库)

效果

我们可以运行这colab,该作者将 Unsplash 的所有素材计算了 clip image embedding ,然后使用 clip word embedding 进行配图。

preview

效果看起来似乎不错,几乎实现了通过一句话就找到合适的图片。不可否定,会存在大量的badcase。但是在不需要fine-tune/下游任务,直接zero-shot得到的embedding可以实现这样的效果已经很厉害了。

作者团队来自 OPEN AI

CLIP工作:

1 方法简单,效果好

2 迁移学习能力强(已训练好的模型,可以在任意数据集上取得好效果)

Q1 CLIP 是什么?how 做zero-shot(一种分类方式)?

Q2 CLIP How 预训练?

利用信号(来自自然语言处理)训练一个模型(迁移效果好)

Q3 经过预训练能得到什么?

A3 仅得到图片或文本的特征,没有在分类任务上继续做训练或微调。即CLIP没有分类头。

Q4 没有分类头,how 做推理?

A4 利用自然语言的方法—prompt template;

将1000个类,生成一个1000个句子(object),

例子:plane 变成 object

1000个句子通过文本编码器(text-Encode)生成1000个特征。

Q5 直接从1000个类里面抽取特征也可以, why 还要进行Prompt- template ?

A5 在预训练时,model 看到的是sample-pair,若在推理时,把所有文本变成一个单词(word),导致model看到的东西和预训练时不一样,导致识别效果稍下降。

Q6 如何将1000分类变成 1000个句子(object)?why 这样做?

A6 2个方法:prompt engineering 和 prompot ensambol;

提高模型准确率,且不需要重新训练模型

Q7 prompt template 操作之后要干嘛?

A7 input 图片,经过image_Encode 得到图片特征,利用image_feature 和 text-feature 计算相似性,挑出值(最相似),进而完成分类任务。

Q8 how 理解分类任务?

A8 judge image 中有哪些物体

Q9 text and image 可以改吗?

A9 yes (all of anything)

Q10 若用imageNet做训练,input三轮车(image),why 得到车,而不是,三轮车?

A10 因为,imageNet 无法实时更新已有类别。

但CLIP可以实时更新,故 imput = output。

这也是CLIP的强大之处,彻底摆脱了categoricel label 限制。

为了提高model泛化性,作者提出新办法,从 text 中提取监督信号。(正是有了监督信号(覆盖范围广)的存在,model 的泛化能力得到提高)作者利用4亿 text-image-pair-dataset ,选择自监督训练方式,进而训练模型。

CLIP 利用多模态对比学习完成训练,并可以做物体分类(即prompt),这种分类不限于已有类别,可扩展到新类别。(即当前学到的model,可以直接在 downstream tasks(下游任务)上做推理。

2017年有人研究,但是影响力小,效果差:

主要有3个工作(均基于transformer)和CLIP像,但有区别:VIrTex:用自回归预测方式,做model预训练,ICMLM :用完型填空方式,做model预训练,ConVIRT:和CLIP类似,但仅在医疗图像上做实验。但由于data 和 model规模小,所以效果不好。

利用自然语言的监督信号,来训练好的视觉模型。在自监督学习(完型填空)的范式下,NLP可以利用(取不尽)的文本监督信号。用此方法训练出的模型,简单,泛化力强,为多模态训练铺路。

why 用自然语言监督信号训练视觉模型?

1 无需标注这些数据(数据规模变大)

2 此时监督信号是文本(不是n选1 的标签),意味着input,output 自由度大了很多.

3 因为image-text-pair数据,model所学特征不单是视觉特征,而是多模态特征。当image和语言联系在一起,便容易做zero-shot迁移学习。

若仅做单模态自监督学习,无论是单模态对比学习(MOCO),还是单模态掩码学习(MAE),model仅学到视觉特征,无法和自然语言联系在一起,依旧很难做zero-shot的迁移。需要大量的image-text-pair(4亿个Image-text-pair)

总结:用 文本监督信号来训练视觉model 这种做法很有潜力。

整个训练过程:

给定一张image,来预测文本,会产生较大歧义(即可能性太多);若逐字句预测文本,太难了。会导致模型训练慢因此采用对比学习,让 model 判断,image 和 text 是否配对。把 ” 训练任务 “ 换成 ” 对比任务 “ ,训练效率提高4倍

2个输入:image 和 text归一化,投射层:将单模态变成多模态,获得 n 个图像的特征,n 个文本的特征。计算 image-feature 和 text-feature相似度。利用相似度做分类。利用交叉熵目标函数计算loss

细节:

1 由于收集数据大,model 不存在 overfitting

简化了工作:当训练CLIP – model时,对应的 image-Encode 和 text-Encode 无需进行预训练

2 在多模态训练中,投射时,用线性投射层,

非线性投射层(作者推测,适配纯 Image 单模态学习),带来10个点的性能提升

3 使用 ” 随即裁减 “ 进行数据增强

4 数据集和 model 太大,不好调参

5 temperature parm(极重要超参数)稍调,model 性能会提高很多 但作者,将其设置成,可学习的标量

模型选择和参数设置:

视觉方面:

训练8个model;

ResNet = 5个,VIT = 3个

残差网络变体ResNet50*4:*16:*64:用 efficientNet-style 方法 将input-image 大小,channel宽度,model-depth 做微调

针对 transformer-model,作者选择数据集 VIT-B/32:/16:/64(阿拉伯数字表示patch大小)

文本方面:transformer

all-model 训练了epoch = 32;Adam optimizer优化器;手动调整超参数,用ResNet-50作为超参搜索,为了快速调参(训练epoch = 1)训练时:选用 batch-size = 32768(很大)(此model在很多机器上做分布式训练)

CLIP 文章的核心 = Zero-shot Transfer

作者研究迁移学习的动机:之前自监督or无监督的方法,主要研究 frature 学习的能力,model的目标是学习泛化性能好的特征,虽然学习到good-feature,但down-work中,还是需要有标签数据做微调。作者想仅训练一个model,在down-work中不再微调。

衡量model 学到的feature 好不好的方法有主要有2种:第一种:linear:冻结训练好的model,再训练一个分类头。第二种:微调:把整个网络放开,做end-to-end的学习。微调的优点: 灵活、当down-work数据集大,微调效果好。但这里作者使用只训练liner分类头: CLIP本就用来研究更数据集无关的训练方式,若用 “ 微调 “ 方法,无法判断预训练model效果如何。(因为,如果预训练model效果不好,经过在down-work上做微调,会导致最终结果好。)

CLIP这么强大,它有什么缺点?

平均来看,CLIIP可以和机械模型(ResNet-50(在ImageNet上训练))持平

若继续增加数据集和model规模,CLIP性能可以继续提高,但是代价很大(需提高计算和数据的高效性)

zreo-shot结果并不好

1 在细分类数据集上,CLIP效果低于(有监督训练)ResNet-50(极限网络)

2 CLIP无法处理抽象概念原因:CLUP无法区分 what is 异常?what is 安全?例如:数一数图片中的物体个数;在视频中,区分这一帧是异常还是非异常;作者提出:在很多领域,CLIP性能和瞎猜差不多

3 若数据集中的data 已经 out-of-distribution,那么CLIP-model泛化照样差;例子:在MNIST数据集上,CLIP准确率仅有88% 。推测原因:作者收集的数据集有4亿个样本,但没有和MINIS长得像的,所以MINIS数据集对于CLIP来说就是out-of-distribution数据集

评价

创新度高1 打破固定类别标签做法2 放飞视觉model训练过程3 引发后续大量工作

有效性高1 大数据集,效果好2 泛化性能好3 zero-shot性能超过人类

运用 BERT 的 MLM 模型进行小样本学习

转载自《必须要GPT3吗?不,BERT的MLM模型也能小样本学习》《P-tuning:自动构建模版,释放语言模型潜能》,作者:苏剑林,部分内容有修改。

大家都知道现在 GPT3 风头正盛,然而,到处都是 GPT3、GPT3 地推,读者是否记得 GPT3 论文的名字呢?事实上,GPT3 的论文叫做《Language Models are Few-Shot Learners》,标题里边已经没有 G、P、T 几个单词了,只不过它跟开始的 GPT 是一脉相承的,因此还是以 GPT 称呼它。顾名思义,GPT3 主打的是 Few-Shot Learning,也就是小样本学习。此外,GPT3 的另一个特点就是大,最大的版本多达 1750 亿参数,是 BERT Base的一千多倍。

正因如此,前些天 Arxiv 上的一篇论文《It’s Not Just Size That Matters: Small Language Models Are Also Few-Shot Learners》便引起了笔者的注意,意译过来就是“谁说一定要大的?小模型也可以做小样本学习”。显然,这标题对标的就是 GPT3,于是笔者饶有兴趣地点进去看看是谁这么有勇气挑战 GPT3,又是怎样的小模型能挑战 GPT3?经过阅读,原来作者提出通过适当的构造,用 BERT 的 MLM 模型也可以做小样本学习,看完之后颇有一种“原来还可以这样做”的恍然大悟感~在此与大家分享一下。

冉冉升起的 MLM

MLM,全称“Masked Language Model”,可以翻译为“掩码语言模型”,实际上就是一个完形填空任务,随机 Mask 掉文本中的某些字词,然后要模型去预测被 Mask 的字词,示意图如下:

BERT 的 MLM 模型简单示意图

其中被 Mask 掉的部分,可以是直接随机选择的 Token,也可以是随机选择连续的能组成一整个词的 Token,后者称为 Whole Word Masking (WWM)。

开始,MLM 仅被视为 BERT 的一个预训练任务,训练完了就可以扔掉的那种,因此有一些开源的模型干脆没保留 MLM 部分的权重,比如 brightmart版 和 clue版 的 RoBERTa,而哈工大开源的 RoBERTa-wwm-ext-large 则不知道出于什么原因随机初始化了 MLM 部分的权重,因此如果要复现本文后面的结果,这些版本是不可取的。

然而,随着研究的深入,研究人员发现不止 BERT 的 Encoder 很有用,预训练用的 MLM 本身也很有用。比如论文《BERT has a Mouth, and It Must Speak: BERT as a Markov Random Field Language Model》指出 MLM 可以作为一般的生成模型用,论文《Spelling Error Correction with Soft-Masked BERT》则将 MLM 用于文本纠错,笔者之前在《从语言模型到Seq2Seq:Transformer如戏,全靠Mask》的实验也表明 MLM 的预训练权重也可以当作 UniLM 来用做 Seq2Seq 任务,还有《无监督分词和句法分析!原来BERT还可以这样用》一文将 MLM 的思想用于无监督分词和句法分析了。可以说 MLM 已经是大放异彩了。

将任务转成完形填空

在本文里,我们再学习 MLM 的一个精彩应用:用于小样本学习或半监督学习,某些场景下甚至能做到零样本学习。

怎么将我们要做的任务跟 MLM 结合起来呢?很简单,给任务一个文本描述,然后转换为完形填空问题即可。举个例子,假如给定句子“这趟北京之旅我感觉很不错。”,那么我们补充个描述,构建如下的完形填空:______满意。这趟北京之旅我感觉很不错。

进一步地,我们限制空位处只能填一个“很”或“不”,问题就很清晰了,就是要我们根据上下文一致性判断是否满意,如果“很”的概率大于“不”的概率,说明是正面情感倾向,否则就是负面的,这样我们就将情感分类问题转换为一个完形填空问题了,它可以用 MLM 模型给出预测结果,而 MLM 模型的训练可以不需要监督数据,因此理论上这能够实现零样本学习了。

多分类问题也可以做类似转换,比如新闻主题分类,输入句子为“八个月了,终于又能在赛场上看到女排姑娘们了。”,那么就可以构建下面报导一则______新闻。八个月了,终于又能在赛场上看到女排姑娘们了。

这样我们就将新闻主题分类也转换为完形填空问题了,一个好的 MLM 模型应当能预测出“体育”二字来。

还有一些简单的推理任务也可以做这样的转换,常见的是给定两个句子,判断这两个句子是否相容,比如“我去了北京”跟“我去了上海”就是矛盾的,“我去了北京”跟“我在天安门广场”是相容的,常见的做法就是将两个句子拼接起来输入到模型做,作为一个二分类任务。如果要转换为完形填空,那该怎么构造呢?一种比较自然的构建方式是:我去了北京?______,我去了上海。
我去了北京?______,我在天安门广场。

其中空位之处的候选词为 是的,不是是的,不是。

Pattern-Exploiting

读到这里,读者应该不难发现其中的规律了,就是给输入的文本增加一个前缀或者后缀描述,并且 Mask 掉某些 Token,转换为完形填空问题,这样的转换在原论文中称为 Pattern,这个转换要尽可能与原来的句子组成一句自然的话,不能过于生硬,因为预训练的 MLM 模型就是在自然语言上进行的。显然同一个问题可以有很多不同的 Pattern,比如情感分类的例子,描述可以放最后,变成“这趟北京之旅我感觉很不错。__满意。”;也可以多加几个字,比如“觉得如何?__满意。这趟北京之旅我感觉很不错。”。

然后,我们需要构建预测 Token 的候选空间,并且建立 Token 到实际类别的映射,这在原论文中称为 Verbalizer,比如情感分类的例子,我们的候选空间是 很,不很,不,映射关系是 很→正面,不→负面很→正面,不→负面,候选空间与实际类别之间不一定是一一映射,比如我们还可以加入“挺”、“太”、“难”字,并且认为 很,挺,太→正面很,挺,太→正面 以及 不,难→负面不,难→负面,等等。不难理解,不少 NLP 任务都有可能进行这种转换,但显然这种转换一般只适用于候选空间有限的任务,说白了就是只用来做选择题,常见任务的就是文本分类。

刚才说了,同一个任务可以有多种不同的Pattern,原论文是这样处理的:

1、对于每种 Pattern,单独用训练集 Finetune 一个 MLM 模型出来;
2、然后将不同 Pattern 对应的模型进行集成,得到融合模型;
3、用融合模型预测未标注数据的伪标签;
4、用伪标签数据 Finetune 一个常规的(非 MLM 的)模型。

具体的集成方式大家自己看论文就行,这不是重点。这种训练模式被称为 Pattern-Exploiting Training (PET),它首先出现在论文《Exploiting Cloze Questions for Few Shot Text Classification and Natural Language Inference》,本文要介绍的这篇论文则进一步肯定和完善了 Pattern-Exploiting Training 的价值和结果,并整合了多任务学习,使得它在 SuperGLUE 榜单上的小样本学习效果超过了 GPT3。两篇论文的作者是相同的,是一脉相承的作品。

PET 在 SuperGLUE 上的小样本学习的结果

不过要吐槽一个点是,上图中 PET 的 223M 参数,所用的模型是 ALBERT-xxlarge-v2,事实上称 ALBERT 为“小模型”是一种很耍流氓的行为,因为它前向计算的速度并没有得到任何提升。ALBERT-xxlarge 共有 12 层,层与层之间参数是共享的,就前向计算而言,它应该等价于约 2700M(12 倍)参数的 GPT 才对。

PET 中文实践,检验效果

要真正确认一个方法或模型的价值,看论文的实验表格是不够的,论文给出的实验结果谁都不好说能否复现,其次就算英文上能复现也不代表中文上有s价值,因此最实际的还是亲自动手做实验验证。下面是笔者的实验代码,供读者参考:Github地址:https://github.com/bojone/Pattern-Exploiting-Training

我们将从以下几个角度来探讨 PET 的可行性:

1、直接利用现成的 MLM 模型效果如何?(零样本学习1)
2、用“大量无标签数据”微调现成的 MLM 模型效果如何?(零样本学习2)
3、用“小量标签数据”微调现成的 MLM 模型效果如何?(小样本学习)
4、用“小量标签数据+大量无标签数据”微调现成的 MLM 模型效果如何?(半监督学习)

下面主要给出情感二分类的实验结果。另外还有一个新闻主题的多分类,代码也放到 Github 了,其结果是类似的,就不重复陈述了。

零样本学习1

这里主要探索的是给输入文本补上对应的 Pattern 后,直接基于现成的 MLM 模型进行预测,预测的准确率。由于构建模型的整个过程都不涉及到标签数据监督训练,因此这算是一种“零样本学习”。我们需要比较的是不同 Pattern、不同 MLM 模型上的效果:

下面是实验的几个 Pattern,其中空位处候选词语都为“很”和“不”:

P1:____满意。这趟北京之旅我感觉很不错。
P2:这趟北京之旅我感觉很不错。____满意。
P3:____好。这趟北京之旅我感觉很不错。
P4:____理想。这趟北京之旅我感觉很不错。
P5:感觉如何?____满意。这趟北京之旅我感觉很不错。

至于 MLM 模型,则是下面几个:

M1:Google 开源的中文版 BERT Base(链接);

M2:哈工大开源的 RoBERTa-wwm-ext Base(链接):

M3:腾讯 UER 开源的 BERT Base(链接);

M4:腾讯 UER 开源的BERT Large(链接)。

实验结果如下表(验证集/测试集):不同模型不同Pattern的零样本学习效果

最好的效果居然可以达到 88%!也就是说,加载现成的 MLM,配合适当的 Pattern,不需要任何标注数据,就可以正确识别大部分样本情感倾向了。这不得不让我们对 MLM 模型的潜力刮目相看了。

可以观察到,不同的 Pattern、不同的预训练模型之间还是有一定的差异的,整体而言 Large 版本的效果要明显好于 Base 版本的模型,说明像 GPT 到 GPT2 再到 GPT3 一样,还是把模型做得更大会更好。此外,这还有可能说明实际上 MLM 还没有被充分训练好,或许是因为 BERT 这种 Mask 掉一部分的训练方式过于低效了,可能用《修改Transformer结构,设计一个更快更好的MLM模型》一文提到的改进版 MLM 会更好。

零样本学习2

看完上述结果,读者可能会想到:如果我用领域内的数据继续预训练 MLM 模型,那么能不能提升效果呢?答案是:能!下面是我们的实验结果,算力有限,我们只在 RoBERTa-wwm-ext(上述的 M2,继续预训练后的模型我们称为 M2+无监督M2+无监督)的基础上做了比较:

要注意的是,这里我们只是用领域内的数据继续做 MLM 训练,这个过程是无监督的,也不需要标注信号,因此也算是“零样本学习”。同时,从到目前为止的结果我们可以看出,给输入本文加入“前缀”的效果比“后缀”更有优势一些。

小样本学习

刚才我们讨论了无标签数据继续预训练 MLM 的提升,如果回到 PET 的目标场景,直接用小量的标签数据配合特定的 Pattern 训练 MLM 又如何呢?这也就是真正的“小样本学习”训练了,这里我们保留约 200 个标注样本,构造样本的时候,我们先给每个句子补上 Pattern,除了 Pattern 自带的 Mask 位置之外,我们还随机 Mask 其他一部分,以增强对模型的正则。最终实验结果如下:

结论就是除了“后缀式”的 P2 之外,其它结果都差不多,这进一步说明了“前缀式”的 Pattern 会比“后缀式”更有竞争力一些。在效果上,直接用同样的数据用常规的方法去微调一个 BERT 模型,大概的结果是 88.93 左右,所以基于“MLM+Pattern”的小样本学习方法可能带来轻微的性能提升。

半监督学习

无监督的零样本学习和有监督的小样本学习都说完了,自然就轮到把标注数据和非标注数据都结合起来的“半监督学习”了。还是同样的任务,标注数据和非标注数据的比例大约是 1:99,标注数据带 Pattern,非标注数据不带 Pattern,大家都 Mask 掉一部分 Token 进行 MLM 预训练,最终测出来的效果如下:

还是同样的,“后缀”明显比“前缀”差,“前缀”的效果差不多。具体效果上,则是肯定了额外的无标注数据也是有作用的。直觉上来看,“前缀”比“后缀”要好,大体上是因为“前缀”的 Mask 位置比较固定,微弱的监督信号得以叠加增强?但这也不能解释为什么零样本学习的情况下也是“前缀”更好,估计还跟模型的学习难度有关系,可能句子前面部分的规律更加明显,相对来说更加容易学一些,所以前面部分就学习得更加充分?这一切都还只是猜测。

汇总与结论

将上述结果汇总如下:结果汇总比较

读者还可以对比我们之前在文章《泛化性乱弹:从随机噪声、梯度惩罚到虚拟对抗训练》中用虚拟对抗训练 (VAT) 做半监督学习的结果,可以看到不管是零样本学习、小样本学习还是半监督学习,基于 MLM 模型的方式都能媲美基于 VAT 的半监督学习的结果。我们在做短新闻多分类实验时的结果也是相似的。因此,这说明了 MLM 模型确实也可以作为一个优秀的零样本/小样本/半监督学习器来使用。

当然,基于 MLM 模型的缺点还是有的,比如 MLM 所使用的独立假设限制了它对更长文本的预测能力(说白了空位处的文字不能太长),以及无法预测不定长的答案也约束了它的场景(所以当前只能用于做选择题,不能做生成)。我们期待有更强的 MLM 模型出现,那时候就有可能在所有任务上都能与 GPT3 一较高下了。

什么是模版

前面介绍的 Pattern-Exploiting Training (PET) 方法,其主要的思想是借助由自然语言构成的模版(英文常称 Pattern 或 Prompt),将下游任务也转化为一个完形填空任务,这样就可以用 BERT 的 MLM 模型来进行预测了。比如下图中通过条件前缀来实现情感分类和主题分类的例子:

通过特定模版将情感分类转换为 MLM 任务

通过特定模版将新闻分类转换为 MLM 任务

当然,这种方案也不是只有 MLM 模型可行,用 GPT 这样的单向语言模型(LM)其实也很简单:

通过特定模版将情感分类转换为 LM 任务

通过特定模版将新闻分类转换为 LM 任务

不过由于语言模型是从左往右解码的,因此预测部分只能放在句末了(但还可以往补充前缀说明,只不过预测部分放在最后)。

某种意义上来说,这些模版属于语言模型的“探针”,我们可以通过模版来抽取语言模型的特定知识,从而做到不错的零样本效果,而配合少量标注样本,可以进一步提升效果。

然而,对于某些任务而言,人工构建模版并不是那么容易的事情,模型的优劣我们也不好把握,而不同模型之间的效果差别可能很大,在这种情况下,人工标注一些样本可能比构建模版还要轻松得多。所以,如何根据已有的标注样本来自动构建模版,便成了一个值得研究的问题了。

P-tuning

最近 Arxiv 上的论文《GPT Understands, Too》提出了名为 P-tuning 的方法,成功地实现了模版的自动构建。不仅如此,借助 P-tuning,GPT 在 SuperGLUE 上的成绩首次超过了同等级别的 BERT 模型,这颠覆了一直以来“GPT 不擅长 NLU”的结论,也是该论文命名的缘由。

P-tuning 重新审视了关于模版的定义,放弃了“模版由自然语言构成”这一常规要求,从而将模版的构建转化为连续参数优化问题,虽然简单,但却有效。

模版的反思

首先,我们来想一下“什么是模版”。直观来看,模版就是由自然语言构成的前缀/后缀,通过这些模版我们使得下游任务跟预训练任务一致,这样才能更加充分地利用原始预训练模型,起到更好的零样本、小样本学习效果。

等等,我们真的在乎模版是不是“自然语言”构成的吗?

并不是。本质上来说,我们并不关心模版长什么样,我们只需要知道模版由哪些 token 组成,该插入到哪里,插入后能不能完成我们的下游任务,输出的候选空间是什么。模版是不是自然语言组成的,对我们根本没影响,“自然语言”的要求,只是为了更好地实现“一致性”,但不是必须的。于是,P-tuning 考虑了如下形式的模版:

P-tuning 直接使用 [unused*] 的 token 来构建模版,不关心模版的自然语言性

这里的 [u1]~[u6],代表 BERT 词表里边的 [unused1]~[unused6],也就是用几个从未见过的 token 来构成模板,这里的 token 数目是一个超参数,放在前面还是后面也可以调整。接着,为了让“模版”发挥作用,我们用标注数据来求出这个模板。

如何去优化

这时候,根据标注数据量的多少,我们又分两种情况讨论。

第一种,标注数据比较少。这种情况下,我们固定整个模型的权重,只优化 [unused1]~[unused6] 这几个 token 的 Embedding,换句话说,其实我们就是要学 6 个新的 Embedding,使得它起到了模版的作用。这样一来,因为模型权重几乎都被固定住了,训练起来很快,而且因为要学习的参数很少,因此哪怕标注样本很少,也能把模版学出来,不容易过拟合。

第二种,标注数据很充足。这时候如果还按照第一种的方案来,就会出现欠拟合的情况,因为只有 6 个 token 的可优化参数实在是太少了。因此,我们可以放开所有权重微调,原论文在 SuperGLUE 上的实验就是这样做的。读者可能会想:这样跟直接加个全连接微调有什么区别?原论文的结果是这样做效果更好,可能还是因为跟预训练任务更一致了吧。

P-tuning 在 SuperGLUE 上的表现

此外,在上面的例子中,目标 token 如“很”、“体育”是认为选定的,那么它们可不可以也用 [unused*] 的 token 代替呢?答案是可以,但也分两种情况考虑:1、在标注数据比较少的时候,人工来选定适当的目标 token 效果往往更好些;2、在标注数据很充足的情况下,目标 token 用 [unused*] 效果更好些,因为这时候模型的优化空间更大一些。

增强相关性

在原论文中,P-tuning 并不是随机初始化几个新 token 然后直接训练的,而是通过一个小型的 LSTM 模型把这几个 Embedding 算出来,并且将这个 LSTM 模型设为可学习的。这样多绕了一步有什么好处呢?原论文大概的意思是:LSTM 出现的 token 表示相关性更强,某种程度上来说更像“自然语言”(因为自然语言的 token 之间不是独立的),此外还能防止局部最优。我在 Github 上进一步向作者确认了一下(参考这里),效果上的差别是通过 LSTM 多绕一步的方法可以使得模型收敛更快、效果更优。

然而,这样多了一个LSTM,总感觉有些别扭,而且实现上也略微有点麻烦。按照作者的意思,LSTM 是为了帮助模版的几个 token(某种程度上)更贴近自然语言,但这并不一定要用 LSTM 生成,而且就算用 LSTM 生成也不一定达到这一点。笔者认为,更自然的方法是在训练下游任务的时候,不仅仅预测下游任务的目标 token(前面例子中的“很”、“新闻”),还应该同时做其他 token 的预测

比如,如果是 MLM 模型,那么也随机 mask 掉其他的一些 token 来预测;如果是 LM 模型,则预测完整的序列,而不单单是目标词。这样做的理由是:因为我们的 MLM/LM 都是经过自然语言预训练的,所以我们(迷之自信地)认为能够很好完成重构的序列必然也是接近于自然语言的,因此这样增加训练目标,也能起到让模型更贴近自然语言的效果。经过笔者的测试,加上这样辅助目标,相比单纯优化下游任务的目标,确实提升了效果。

P-tuning 实验与效果

所谓“talk is cheap, show me the code”,又到了喜闻乐见的实验时间了。这里分享一下 P-tuning 的实验结果,其中还包括笔者对 P-tuning 的实现思路,以及笔者在中文任务上的实验结果。

停止的梯度

怎么实现上述的 P-tuning 算法比较好呢?如果是放开所有权重训练,那自然是简单的,跟普通的 BERT 微调没有什么区别。关键是在小样本场景下,如何实现“只优化几个 token”呢?

当然,实现的方法也不少,比如为那几个要优化的 token 重新构建一个 Embedding 层,然后拼接到 BERT 的Embedding层中,然后训练的时候只放开新 Embedding 层的权重。但这样写对原来模型的改动还是蛮大的,最好的方法是尽可能少改动代码,让使用者几乎无感。为此,笔者构思了一种用 stop_gradient 简单修改 Embedding 层的方案,大体上是将 Embedding 层修改如下:

class PtuningEmbedding(Embedding):
    """新定义Embedding层,只优化部分Token
    """
    def call(self, inputs, mode='embedding'):
        embeddings = self.embeddings
        embeddings_sg = K.stop_gradient(embeddings)
        mask = np.zeros((K.int_shape(embeddings)[0], 1))
        mask[1:9] += 1  # 只优化id为1~8的token
        self.embeddings = embeddings * mask + embeddings_sg * (1 - mask)
        return super(PtuningEmbedding, self).call(inputs, mode)

变量经过 stop_gradient 算子后,在反向传播的时候梯度为 0,但是前向传播不变,因此在上述代码中,前向传播的结果不会有变化,但是反向传播求梯度的时候,梯度不为 0 的 token 由 mask 变量控制,其余 token 的梯度都为零,因此就实现了只更新部分 token。

完整代码可见:Github:https://github.com/bojone/P-tuning

对了,原论文也开源了代码:Github:https://github.com/THUDM/P-tuning

测试与效果

前面已经分享了原作者在 SuperGLUE 上的实验结果,显示出如果配合 P-tuning,那么:1、GPT、BERT 的效果相比直接 finetune 都有所提升;2、GPT 的效果还能超过了 BERT。这表明 GPT 不仅有 NLG 的能力,也有 NLU 能力,可谓是把 GPT 的潜能充分“压榨”出来了,当然 BERT 配合 P-tuning 也有提升,说明 P-tuning 对语言模型潜能的释放是较为通用的。

原论文的实验比较丰富,建议读者仔细阅读原论文,相信会收获颇多。特别指出的是原论文的 Table 2 最后一列,当预训练模型足够大的时候,我们的设备可能无法 finetune 整个模型,而 P-tuning 可以选择只优化几个 Token 的参数,因为优化所需要的显存和算力都会大大减少,所以 P-tuning 实则上给了我们一种在有限算力下调用大型预训练模型的思路

P-tuning 在各个体量的语言模型下的效果

当然,笔者一直以来的观点是“没有在中文上测试过的算法是没有灵魂的”,因此笔者也在中文任务上简单测试了,测试任务跟前文一致,都是情感分类的小样本学习,测试模型包括 BERT 和 GPT,两者的候选模版分别如下图:

笔者在中文情感分类上使用的“BERT+P-tuning”模版

笔者在中文情感分类上使用的“GPT+P-tuning”模版

注意,对于 LM 模型,前缀的引入非常重要,只引入后缀时效果会明显变差;而对于 MLM 模型,前缀的效果通常也优于后缀。总的效果如下表:

其中“小样本”只用到了“少量标注样本”,“无监督”则用到了“大量无标注样本”,“半监督”则用到了“少量标注样本+大量无标注样本”,“P-tuning” 都是小样本,PET 的几个任务报告的是最优的人工模版的结果,其实还有更差的人工模版。从小样本角度来看,P-tuning 确实取得了最优的小样本学习效果;从模版构建的角度来看,P-tuning 确实也比人工构建的模版要好得多;从模型角度看,P-tuning 确实可以将 GPT 的分类性能发挥到跟 BERT 相近,从而揭示了 GPT 也有很强的 NLU 能力的事实。

进一步理解 P-tuning

这一节将会介绍笔者对 P-tuning 的进一步思考,以求从多个维度来理解 P-tuning。

离散 vs 连续

在 P-tuning 之前,也已经有一些在做模版的自动构建,如《How Can We Know What Language Models Know?》《AutoPrompt: Eliciting Knowledge from Language Models with Automatically Generated Prompts》等,但它们搜索的都是在离散空间下搜索的自然语言模版,所以效果有所限制,并没有取得特别突出的结果。

相反,P-tuning 放弃了“模版由自然语言构成”这一要求,从而将其变成了可以简单梯度下降求解的连续参数问题,效果还更好。同时,这一改动意味着 P-tuning 突出了模版的本质——即模版的关键在于它是怎么用的,不在于它由什么构成——给人一种去芜存菁、眼前一亮的感觉,确实值得点赞。

(注:经读者@brotherb提醒,年初有一篇论文《Prefix-Tuning: Optimizing Continuous Prompts for Generation》提出的 Prefix-Tuning 方法其实已经相当接近 P-tuning,两者都设计了非自然语言的模版,只不过 Prefix-Tuning 主要关心 NLG 的应用而 P-tuning 更加关心 NLU 的应用。)

Adapter

我们还可以从 Adapter 的角度来理解 P-tuning。BERT 出来后不久,Google 在论文《Parameter-Efficient Transfer Learning for NLP》中提出了一种名为 Adapter 的微调方式,它并不是直接微调整个模型,而是固定住 BERT 原始权重,然后在 BERT 的基础上添加一些残差模块,只优化这些残差模块,由于残差模块的参数更少,因此微调成本更低。Adapter 的思路实际上来源于 CV 的《Learning multiple visual domains with residual adapters》,不过这两年似乎很少看到了,也许是因为它虽然提高了训练速度,但是预测速度却降低了,精度往往还有所损失。

在 P-tuning 中,如果我们不将新插入的 token 视为“模版”,是将它视为模型的一部分,那么实际上 P-tuning 也是一种类似 Adapter 的做法,同样是固定原模型的权重,然后插入一些新的可优化参数,同样是只优化这些新参数,只不过这时候新参数插入的是 Embedding 层。因此,从这个角度看,P-tuning 与 Adapter 有颇多异曲同工之处。

为什么有效

然后,还有一个值得思考的问题:为什么 P-tuning 会更好?比如全量数据下,大家都是放开所有权重,P-tuning 的方法依然比直接 finetune 要好,为啥呢?

事实上,提出这个问题的读者,应该是对 BERT 加个全连接层的直接 finetune 做法“习以为常”了。很明显,不管是 PET 还是 P-tuning,它们其实都更接近预训练任务,而加个全连接层的做法,其实还没那么接近预训练任务,所以某种程度上来说,P-tuning 有效更加“显然”,反而是加个全连接层微调为什么会有效才是值得疑问的。

去年有篇论文《A Mathematical Exploration of Why Language Models Help Solve Downstream Tasks》试图回答这个问题,大致的论证顺序是:

1、预训练模型是某种语言模型任务;
2、下游任务可以表示为该种语言模型的某个特殊情形;
3、当输出空间有限的时候,它又近似于加一个全连接层;
4、所以加一个全连接层微调是有效的。

可以看到,该论文的假设主要是第 2 点,其实就是直接假设了下游任务可以表达为类似 PET 的形式,然后才去证明的。所以这进一步说明了,PET、P-tuning 等才是更自然的使用预训练模型的方式,加全连接直接 finetune 的做法其实只是它们的推论罢了,也就是说,PET、P-tuning 才是返璞归真、回归本质的方案,所以它们更有效。

转载自《必须要GPT3吗?不,BERT的MLM模型也能小样本学习》《P-tuning:自动构建模版,释放语言模型潜能》,作者:苏剑林,部分内容有修改。

Prompt Learning(模板学习)

论文:https://arxiv.org/pdf/2107.13586.pdf

“Prompt:NLP 新范式”

“Pre-train, Prompt, and Predict” —- Prompt可以认为就是下游任务来适应预训练模型而做的微调 (所需数据量少、训练快、效果好),原始的微调是让预训练模型来适应下游任务。

文章摘自:未闻 Prompt 名

个人觉得 2021 年 NLP 最火的两个 idea,一个是对比学习(Contrastive Learning),另一个就是 Prompt

浅谈我对 Prompt 的理解

Prompt 说简单也简单,看了几篇论文以及博客后发现其实就是构建一个语言模版。但是细想起来又觉得复杂,因为总感觉里面还有很多细节,因此本文就来从头梳理一下 Prompt(Prompt 很多地方会翻译成「范式」,但是「范式」这个词本身也不好理解,因此读者把他看作是「模板」即可)

今天我还与室友讨论预训练模型(例如 BERT)到底做了什么,我给出的回答是

预训练模型提供了一个非常好的初始化参数,这组参数在预训练任务上的表现非常好(预训练损失非常低),但是由于下游任务千奇百怪,我们需要在这组参数的基础上进行 Fine-tune 以适应我们的下游任务(使得下游任务的损失值非常低)

上面这段话其实隐含了目前做 NLP 任务的大致流程,即 “Pre-train, Fine-tune”,而对我们来说实际上大部分时候都是直接拿别人预训练好的模型做 Fine-tune,并没有 Pre-train 这一步

融入了 Prompt 的模式大致可以归纳成 “Pre-train, Prompt, and Predict”,在该模式中,下游任务被重新调整成类似预训练任务的形式。例如,通常的预训练任务有 MLM(Masked Language Model),在文本情感分类任务中,对于 “I love this movie” 这句输入,可以在后面加上 Prompt:”the movie is ___”,组成如下这样一句话:

I love this movie, the movie is ___

然后让预训练模型用表示情感的答案(例如 “great”、”terrible” 等)做完形填空,最后再将该答案转换为情感分类的标签。这样一来,我们就可以通过构造合适的「模板」,通过小样本数据集训练一个模型来解决各种各样的下游任务

注意,Prompt 设计的这种完形填空和 MLM(Masked Language Modeling) 任务是有区别的,二者虽然都是都是词分类,但是候选集不同,MLM 的候选词是整个词库,不过如果是生成任务,那么 Prompt 和 MLM 的候选集就是一样的,都是整个词库

如何构建 Prompt

对于输入文本 x,存在一个函数 fPrompt(x),将 x 转化成 x′ 的形式,即

该函数通常会进行两步操作:

  1. 使用一个模板,模板通常为一段自然语言句子,并且该句子包含两个空位置:用于填输入 x 的位置 [X]、用于生成答案文本 z 的位置 [Z]
  2. 把输入 x 填到 [X] 的位置

以前文提到的例子为例,在文本情感分类任务中,假设输入是

x = "I love this movie"

使用的模板是

[X]. Overall, it was a [Z] movie

那么得到的 x′ 就应该是

I love this movie. Overall, it was a [Z] movie

在实际情况中,Prompt 来填充答案的位置一般在句中或句末。如果在句中,一般称这种 Prompt 为 Cloze Prompt;如果在句末,一般称这种 Prompt 为 Prefix Prompt。[X] 和 [Z] 的位置、数量以及使用模板句的不同,都有可能对结果造成影响,因此需要灵活调整

上面讲的都是简单的情感分类任务的 Prompt 设计,读者看到这里自然而然的会想到,其他 NLP 任务的 Prompt 如何设计呢?实际上刘鹏飞大神在他的论文中给我们提供了一些参考

Text Generation 中摘要任务里有一个关键字 TL;DR,这其实是 Too Long; Don't Read 的缩写

Prompt 的选择非常重要且困难

有上述 Prompt 的基础后,我们可以得知 Prompt 的设计主要包含两部分:

  1. 模板 T:例如 [X]. Overall, It was [Z]
  2. 标签词映射:即 [Z] 位置预测输出的词汇集合与真实标签 y 构成的映射关系。例如,标签 positive 对应单词 great,标签 negative 对应单词 terrible

基于 Prompt 的微调方法中,不同的模板和标签词对最终结果影响很大,下图是陈丹琦团队论文中的实验结果

从上图我们可以看出两点:

  1. 使用相同的「模板」,不同的「标签词」会产生不一样的效果。例如 great/terribel 和 cat/dog 这两组标签词的效果不一样,而且即便是相同标签词,互换顺序也会导致最终效果有所变化,例如 cat/dog 和 dot/cat
  2. 使用相同「标签词」,对「模板」进行小改动(例如增删标点)也会呈现不同的结果

Prompt 的设计

Prompt 大概可以从下面三个角度进行设计:

  • Prompt 的形状
  • 人工设计模板
  • 自动学习模板

Prompt 的形状

Prompt 的形状主要指的是 [X] 和 [Z] 的位置和数量。上文提到的 Cloze Prompt 与 Maksed Language Model 的训练方式非常类似,因此对于 MLM 任务来说,Cloze Prompt 更合适;对于生成任务或者使用自回归 LM 解决的任务,Prefix Prompt 更合适。

人工设计模板

Prompt 的模板最开始是人工设计的,人工设计一般基于人类的自然语言知识,力求得到语义流畅且高效的「模板」。例如,Petroni 等人在著名的 LAMA 数据集中为知识探针任务人工设计了 Cloze Templates;Brown 等人为问答、翻译和探针等任务设计了 Prefix Templates。人工设计模板的优点是直观,但缺点是需要很多实验、经验以及语言专业知识。下图是 GPT Understands, Too 论文中的一个实验结果

可以看到不同的 Prompt 只有细微的区别,有的甚至只是增加减少一个词,但是最后的结果会差几十个点

自动学习模板

为了解决人工设计模板的缺点,许多研究员开始探究如何自动学习到合适的模板。自动学习的模板又可以分为离散(Discrete Prompts)和连续(Continuous Prompts)两大类。离散方法主要包括:Prompt Mining,Prompt Paraphrasing,Gradient-based SearchPrompt Generation 和 Prompt Scoring;连续的则主要包括 Prefix TuningTuning Initialized with Discrete promptsHard-Soft Prompt Hybrid TuningP-Tuning v2

离散 Prompts

简单说一下上述几种方法,首先是离散的 Prompt Mining,这篇文章发表在 TACL 2020,讲的是如何拿预训练语言模型当作「知识库」使用,并且引入了依存树和 Paraphrase(转述)等方法来挖掘更好的「模板」,下图是实验结果

可以看到,被挖掘出来的若干「连接谓词」相比于人工设计的「模板」结果提升还是很明显的

有很多种方法可以实现 Prompt Paraphrsing,例如「回译」,我们通过 DeepL 翻译看个例子:

这样我们就得到了 x shares a border with y 的一个 Prompt Paraphrasing:x and y share a boundary

论文 BARTScore 干脆给我们提供了一张表,里面有各种词组的同义替换,这个我再熟悉不过了,因为以前英语考试我也背过类似的东西

Gradient-based Search(基于梯度的搜索)是由论文 AUTOPROMPT 提出的,这篇文章发表在 EMNLP 2020,它的主要思想用下面这张图就可以表示

上图中,a real joy 是原始的输入句子 xinp,红色的 Trigger tokens 是由 xinp「激发」的相关词汇集合 xtrig,根据 Template λ 的配置,将 xtrig 和 xinp 组合起来构造最终的输入 xprompt,送入 Masked LM 预测情感标签。下面的表格增加了很多 NLP 其他任务的例子

关于如何生成 xtrig 集合,实际上主要使用的是 HotFlip 和对抗训练的思想,感兴趣的同学可以看原论文以及 HotFlip: White-box adversarial examples for text classificationUniversal Adversarial Triggers for Attacking and Analyzing NLP 这两篇论文

Prompt Generation 是陈丹琦团队的一项工作,主要是把 Seq2Seq 预训练模型 T5 应用到模板搜索的过程。T5 基于多种无监督目标进行预训练,其中最有效的一个无监督目标就是:利用 <X> 或 < Y > 替换一个或多个连续 span,然后生成对应输出。例如:

Thank you <X> me to your party <Y> week

T5 会在 <X> 生成 for inviting,在 <Y> 生成 last。很显然,T5 这种方式很适合生成模板,而且不需要指定模板的 token 数。具体来说,有三种可能的生成方式⟨S1⟩→⟨X⟩ M(y) ⟨Y⟩ ⟨S1⟩⟨S1⟩→⟨S1⟩ ⟨X⟩ M(y) ⟨Y⟩⟨S1⟩,⟨S2⟩→⟨S1⟩ ⟨X⟩ M(y) ⟨Y⟩ ⟨S2⟩

具体的模板生成过程如下图所示:

首先在标签词前后添加填充位 <X> 和 < Y>(上面提到的三种生成方式),然后将其送入 T5 模型中,T5 会自动在填充位生成序列,最后将标签词(great 或 terribel)转换为 [MASK] 标签,形成多个模板。具体过程中采用 Beam Search 的方法生成多个候选模板,然后对每一个候选模板利用 dev 集进行微调,选择其中一个最佳模板

我还想说一下这篇论文中另外一个有意思的点,最后送入模型进行预测的句子还拼接上了每种类别的「示例」(Demonstration),如下图所示

这种 Prompt 的设计有点像是在做语义相似度任务,X 为原始 Input 句子,已知 Y 为正例,Z 为负例,构造了如下形式的输入:

X是[MASK]例?Y为正例;Z为负例

这有点像是编程语言中的三目运算符,或者说相当于让模型比较 X 与 Y、Z 的语义相似度。这里我们自然而然会想问:Y、Z 是如何挑选出来的?实际上是依据下面两条规则:

  1. 对于每个原始输入句子,从每个类别中随机采样一个样本「示例」拼接到 Prompt 中
  2. 对于每个原始输入句子,在每个类别中,通过与 Sentence-BERT 进行相似度计算,从相似度最高的前 50% 样本中随机选择一个样本「示例」

连续 Prompts

构造 Prompt 的初衷是能够找到一个合适的方法,让 Pre-trained Language Model(PLM)更好地输出我们想要的结果,但其实并不一定要将 Prompt 的形式设计成人类可以理解的自然语言,只要机器理解就行了。因此,还有一些方法探索连续型 Prompts—— 直接作用到模型的 Embedding 空间。连续型 Prompts 去掉了两个约束条件:

  1. 模版中词语的 Embedding 可以是整个自然语言的 Embedding,不再只是有限的一些 Embedding
  2. 模版的参数不再直接取 PLM 的参数,而是有自己独立的参数,可以通过下游任务的训练数据进行调整

Prefix Tuning 最开始由 Li 等人提出,这是一种在输入句子前添加一组连续型向量的方法,该方法保持 PLM 的参数不动,仅训练前缀(Prefix)向量。Prefix Tuning 的提出主要是为了做生成任务,因此它根据不同的模型结构定义了不同的 Prompt 拼接方式,在 GPT 类的 Auto-Regressive(自回归)模型上采用的是 [Prefix;x;y] 的方式,在 T5 类的 Encoder-Decoder 模型上采用的是 [Prefix;x;Prefix′;y] 的方式

输入部分 Prefix, \(x, y\) 的 Position id 分别记作

\(\mathrm{P}{\mathrm{idx}}\) , \(\mathrm{X}{\mathrm{idx}}\) , \(\mathrm{Y}{\mathrm{idx}}\)。Prefix Tuning 初始化一 个可训练的矩阵,记作 \(P\theta \in \mathbb{R}^{\left|P_{\mathrm{idx}}\right| \times \operatorname{dim}\left(h_i\right)}\) ,其中
\(h_i= \begin{cases}P_\theta[i,:], & \text { if } i \in \mathrm{P}{\mathrm{idx}} \ \mathbf{L M}\phi\left(z_i, h_{<i}\right), & \text { otherwise }\end{cases}\)
上述公式的含义是,索引 $i$ 如果属于前缀的部分,则从 \(P_\theta\) 中抽取向量; \(i\) 如果不是前缀部 分,则由参数固定的预训练模型生成对应的向量。训练目标为:
\(\max \phi \log p\phi(y \mid x)=\sum_{i \in \mathrm{Y}{\mathrm{idx}}} \log p\phi\left(z_i \mid h_{<i}\right)\)

 \(P_\theta\) 本质上是一个矩阵,而生成一个矩阵的方法又很多,可以用 nn.Embedding(),或者 nn.Linear()

同样是在连续空间上搜索 Prompt,OptiPrompt 构建的「模板」并不局限于前缀,也可以在句子的中间

Hard-Soft Prompt Hybrid Tuning 方法可以说是人工设计和自动学习的结合,它通常不单纯使用可学习的 Prompt 模板,而是在人工设计的模板中插入一些可学习的 Embedding。实际上有了上面的基础我们都知道,连续的 Prompt 要比离散的 Prompt 好一点,但是在此基础上还有什么改进的余地吗?Liu 等人提出的 P-Tuning 解决了 Prompt token 之间的关联性问题

之前连续的 Prompt 生成方式无非都是训练一个矩阵,然后通过索引出矩阵的某几行向量拼起来。坦白地说,我们希望这些 prompt token Embedding 之间有一个比较好的关联性,而不是独立地学习,为了解决这个问题,P-Tuning 引入了一个 Prompt Encoder(如下图 b 所示)

上图 a 是传统的离散型 Prompt,我们把生成离散 Prompt token 的东西叫做 Prompt Generator;上图 b 首先传入一些 Virtual(Pseudo)token,例如 BERT 词表中的 [unused1],[unused2],… 当然,这里的 token 数目是一个超参数,插入的位置也可以调整。将这些 Pseudo token 通过一个 Prompt Encoder 得到连续的向量 h0,…,hm,其中

大家可能想问,如何优化 P-tuning?实际上根据标注数据量的多少,分两种情况讨论

  1. 标注数据比较少。这种情况,我们固定 PLM 的参数,只优化 [P0]∼[Pm] 这几个 token 的 Embedding。换句话说,我们只是要更新 Prompt Encoder 的参数
  2. 标注数据很充足。这种情况直接放开所有参数微调

就在 P-Tuning 方法提出不久后,Liu 等人又提出了 P-Tuning v2,主要解决 P-Tuning 的两个问题:

  1. 当预训练模型的参数量低于 100 亿(10B)时,Prompt tuning 会比传统的 Fine-tuning 差
  2. 诸如序列标注这样对推理和理解要求高的任务,prompt tuning 效果会变差

Liu 等人认为先前的 P-Tuning 只用了一层 BiLSTM 来编码 Pseudo token,这是其推理能力不足的原因之一,因此 v2 版本提出 Deep Prompt Tuning,用 Prefix Tuning 中的深层模型替换 BiLSTM,如下图所示

P-Tuning v2 相比于 P-Tuning,区别在于:

  • 取消 Reparameterization:以前的方法利用重参数化功能来提高训练速度和鲁棒性(例如,用于 Prefix-Tuning 的 MLP 和用于 P-Tuning 的 LSTM)。在 P-Tuning v2 中,作者发现重参数化的改进很小,尤其是对于较小的模型,同时还会影响模型的表现
  • Multi-task Learning:Deep Prompt Tuning 的优化难题可以通过增加额外的任务数据或者无标注数据来缓解,同时可微调的 Prefix Continuous Prompt 也可以用来做跨任务的知识共享。例如在 NER 中,可以同时训练多个数据集,不同数据集使用不同的顶层 Classifier,但是 Prefix Continuous Prompt 是共享的
  • 取消 verbalizer:v2 取消了标签映射,完全变为生成模型,可以在 [CLS] 部分输出句子级别的标签(Sentence-level label),也可以在每个 token 位置输出 token 级别的标签(Token-level label),直接输出真实标签

关于 P-Tuning 还有一些碎碎念,主要是从各个博客上看到的,汇总在这里。首先是 v1 版本的 LSTM,实际上引入 LSTM 目的是为了帮助「模板」生成的 token(某种程度上)更贴近自然语言,或者说 token 之间的语义更流畅,但更自然的方法应该是在训练下游任务的时候,不仅预测下游任务的目标 token(例如 “great”、”terrible”),还应该同时做其他 token 的预测

比如,如果是 MLM 模型,那么也随机 MASK 掉其它的一些 token 来预测,如果是 LM 模型,则预测完整的序列,而不单单是目标词。这样做的理由是:因为我们的 MLM/LM 都是经过自然语言预训练的,所以我们认为它能够很好的完成序列的重构,即便一开始不能,随着迭代轮数的增加,模型也能很好完成这项任务。所以这本质上是让模型进行「负重训练」

* 为什么要引入 Prompt?

在标准的 Fine-tune 过程中(如上图 b 所示),新引入的参数量可能会很大(独立于原始预训练模型外的参数),例如基于 RoBERTa-large 的二分类任务会新引入 2048 个参数(nn.Linear(1024, 2)),如果你仅有例如 64 个标注数据这样的小样本数据集,微调会非常困难

为解决这一问题,Prompt 应运而生(如上图 a 所示),直接将下游任务转换为输出空间有限的 MLM 任务。值得注意的是:上述方法在预训练参数的基础上进行微调,并且没有引入任何新参数,同时还减少了微调和预训练任务之间的差距。总的来说,这可以更有效地用于小样本场景

Prompt 的挑战与展望

尽管 Prompt 研究搞得如火如荼,但目前仍存在许多问题值得研究者们去探究

  1. Prompt 的设计问题。目前使用 Prompt 的工作大多集中于分类任务和生成任务,其它任务则较少。另外,「模板」和「答案」的联系也亟待解决。模型的表现同时依赖于使用的「模板」和「答案」的映射,如何同时搜索或者学习出两者联合的最好效果仍然很具挑战性
  2. Prompt 的理论分析和可解释性。尽管 Prompt 方法在很多情况下都取得了成功,但是目前 Prompt-based Learning 理论分析还很少,人们很难了解 Prompt 为什么能达到好的效果,又为什么在自然语言中意义相近的 Prompt 有时效果却相差很大
  3. Prompt 在 PLM debias 方面的应用。由于 PLM 在预训练过程中见过了大量的人类世界的自然语言,所以很自然地会受到一些影响。举一个简单的例子,比如说训练语料中有非常多 “The capital of China is Beijing”,导致模型每次看到 “capital” 的时候都会预测出 “Beijing”,而不是去分析到底是哪个国家的首都。在应用的过程中,Prompt 还暴露了 PLM 学习到的很多其它 bias,比如种族歧视、性别对立等。这也许会是一个值得研究的方向

One More Thing

最后我还想提一个实际 Code 过程中存在的问题。我们知道 MLM 任务会输出句子中 [MASK] 位置最有可能的词,而 Prompt 也类似的,例如下面的例子

这是一条__新闻。中国足球出线的可能性只有0.001%,留给中国队的时间不多了

这是一个新闻分类问题,真实标签有 “体育”、”财经”、”娱乐” 等,上面的样本很明显是一条体育新闻,因此我们希望模型对 [MASK] 部分输出 “体育”,但事实真的如此吗?实际情况模型的输出可能是 “足球”,但你认为模型预测的 “足球” 有问题吗?好像也没啥毛病,因此这就引申出了 Prompt 的一个问题,是否应该限制模型的输出空间?

还是上面新闻分类的例子,我们是否应该限制模型输出的空间,让他固定只能预测 “体育”、”财经”、”娱乐” 这几个标签?或者我们干脆把这几个标签换成索引,那就是让模型从 0,1,2 这三个数字选一个。Wait Wait Wait,如果这么做的话,和 Fine-Tune 有什么区别,Fine-Tune 也是把标签转换成索引,让模型看了句子之后,从这几个索引中选一个作为预测值

这么说的话,那我们就不应该限制模型的输出空间,可是这样的话 [MASK] 位置的输出就限制的太死了,必须一定是 “good”、”财经” 才算对,如果输出 “nice”、”财政” 就算错。实际上输出近义词或者相似词,在零样本的情况下会经常出现,但是如果你用一些有标签的样本去训练,模型自己就会慢慢固定输出空间。例如 “财经”,它不会预测成 “财政”,只会预测成其它类型的新闻,例如 “体育”

References

自然语言处理中注意力机制综述

注意力汇聚

查询(自主提示)和键(非自主提示)之间的交互形成了 注意力汇聚(attentionpooling)。注意力汇聚有选择地聚合了值(感官输入)以生成最终的输
出。注意力汇聚(attention pooling)公式:

其中 x 是查询,(xi; yi) 是键值对。注意力汇聚是 yi 的加权平均。将查询 x 和键 xi之间的关系建模为 注意力权重(attetnion weight) (x; xi),如 (10.2.4) 所示,这个权重将被分配给每一个对应值 yi。对于任何查询,模型在所有键值对上的注意力权重都是一个有效的概率分布:它们是非负数的,并且总和为1。

正如我们所看到的,选择不同的注意力评分函数 a 会导致不同的注意力汇聚操作。

1、加性注意力

2、缩放点积注意力

使用点积可以得到计算效率更高的评分函数。但是点积操作要求查询和键具有相同的⻓度d。假设查询和键的所有元素都是独立的随机变量,并且都满足均值为 0 和方差为 1。那么两个向量的点积的均值为 0,方差为 d。为确保无论向量⻓度如何,点积的方差在不考虑向量⻓度的情况下仍然是 1,则可以使用 缩放点积注意力(scaled dot-product attention)评分函数:

1. 写在前面

近些年来,注意力机制一直频繁的出现在目之所及的文献或者博文中,可见在nlp中算得上是个相当流行的概念,事实也证明其在nlp领域散发出不小得作用。这几年得顶会paper就能看出这一点。本文深入浅出地介绍了自然语言处理中的注意力机制技术。据Lilian Weng博主总结以及一些资料显示,Attention机制最早应该是在视觉图像领域提出来的,这方面的工作应该很多,历史也比较悠久。人类的视觉注意力虽然存在很多不同的模型,但它们都基本上归结为给予需要重点关注的目标区域(注意力焦点)更重要的注意力,同时给予周围的图像低的注意力,然后随着时间的推移调整焦点。而直到Bahdanau等人发表了论文《Neural Machine Translation by Jointly Learning to Align and Translate》,该论文使用类似attention的机制在机器翻译任务上将翻译和对齐同时进行,这个工作目前是最被认可为是第一个提出attention机制应用到NLP领域中,值得一提的是,该论文2015年被ICLR录用,截至现在,谷歌引用量为5596,可见后续nlp在这一块的研究火爆程度。

注意力机制首先从人类直觉中得到,在nlp领域的机器翻译任务上首先取得不错的应用成功。简而言之,深度学习中的注意力可以广义地解释为重要性权重的向量:为了预测一个元素,例如句子中的单词,使用注意力向量来估计它与其他元素的相关程度有多强,并将其值的总和作为目标的近似值

既然注意力机制最早在nlp领域应用于机器翻译任务,那在这个之前又是怎么做的呢?传统的基于短语的翻译系统通过将源句分成多个块然后逐个词地翻译它们来完成它们的任务。 这导致了翻译输出的不流畅。想想我们人类是如何翻译的?我们首先会阅读整个待翻译的句子,然后结合上下文理解其含义,最后产生翻译。在某种程度上,神经机器翻译(NMT)的提出正是想去模仿这一过程。而在NMT的翻译模型中经典的做法是由编码器 – 解码器架构制定(encoder-decoder),用作encoder和decoder常用的是循环神经网络。这类模型大概过程是首先将源句子的输入序列送入到编码器中,提取最后隐藏的表示并用于初始化解码器的隐藏状态,然后一个接一个地生成目标单词,这个过程广义上可以理解为不断地将前一个时刻 t-1 的输出作为后一个时刻 t 的输入,循环解码,直到输出停止符为止。通过这种方式,NMT解决了传统的基于短语的方法中的局部翻译问题:它可以捕获语言中的长距离依赖性,并提供更流畅的翻译。但是这样做也存在很多缺点,譬如,RNN是健忘的,这意味着前面的信息在经过多个时间步骤传播后会被逐渐消弱乃至消失。其次,在解码期间没有进行对齐操作,因此在解码每个元素的过程中,焦点分散在整个序列中。对于前面那个问题,LSTM、GRU在一定程度能够缓解。而后者正是Bahdanau等人重视的问题。

 2、NLP中Attention mechanism的起源

在Seq2Seq结构中,encoder把所有的输入序列都编码成一个统一的语义向量context,然后再由decoder解码。而context自然也就成了限制模型性能的瓶颈。譬如机器翻译问题,当要翻译的句子较长时,一个context可能存不下那么多信息。除此之外,只用编码器的最后一个隐藏层状态,感觉上都不是很合理。实际上当我们翻译的时候譬如:Source:机器学习–>Target:machine learning。当decoder要生成”machine”的时候,应该更关注”机器”,而生成”learning”的时候,应该给予”学习”更大的权重。所以如果要改进Seq2Seq结构,一个不错的想法自然就是利用encoder所有隐藏层状态解决context限制问题。

Bahdanau等人把attention机制用到了神经网络机器翻译(NMT)上。传统的encoder-decoder模型通过encoder将Source序列编码到一个固定维度的中间语义向量context,然后在使用decoder进行解码翻译到目标语言序列。前面谈到了这种做法的局限性,而且,Bahdanau等人在摘要也说到这个context可能是提高这种基本编码器 – 解码器架构性能的瓶颈,那Bahdanau等人又是如何尝试缓解这个问题的呢?让我们来一探究竟,作者为了缓解中间向量context很难将Source序列所有必要信息压缩进来的问题,特别是对于那些很长的句子。提出在机器翻译任务上在 encoder–decoder 做出了如下扩展:将翻译和对齐联合学习。这个操作在生成Target序列的每个词时,用到的中间语义向量context是Source序列通过encoder的隐藏层的加权和,而传统的做法是只用encoder最后一个输出 ht 作为context,这样就能保证在解码不同词的时候,Source序列对现在解码词的贡献是不一样的。想想前面那个例子:”Source:机器学习–>Target:machine learning”(假如中文按照字切分)。decoder在解码”machine”时,”机”和”器”提供的权重要更大一些,同样,在解码”learning”时,”学”和”习”提供的权重相应的会更大一些,这在直觉也和人类翻译也是一致的。通过这种attention的设计,作者将Source序列的每个词(通过encoder的隐藏层输出)和Target序列(当前要翻译的词)的每个词巧妙的建立了联系。想一想,翻译每个词的时候,都有一个语义向量,而这个语义向量是Source序列每个词通过encoder之后的隐藏层的加权和。 由此可以得到一个Source序列和Target序列的对齐矩阵,通过可视化这个矩阵,可以看出在翻译一个词的时候,Source序列的每个词对当前要翻译词的重要性分布,这在直觉上也能给人一种可解释性的感觉。

3. NLP中的注意力机制

随着注意力机制的广泛应用,在某种程度上缓解了源序列和目标序列由于距离限制而难以建模依赖关系的问题。现在已经涌现出了一大批基于基本形式的注意力的不同变体来处理更复杂的任务。让我们一起来看看其在不同NLP问题中的注意力机制。

其实我们可能已经意识到了,对齐模型的设计不是唯一的,确实,在某种意义上说,根据不同的任务设计适应于特定任务的对齐模型可以看作设计出了新的attention变体,让我们再看看这个模型(函数): score(st,hi) 。再来看几个代表性的work。

Citation等人提出Content-base attention,其对齐函数模型设计为:

Bahdanau等人的Additive(*),其设计为:

Luong[4]等人文献包含了几种方式:

以及Luong等人还尝试过location-based function。这种方法的对齐分数仅从目标隐藏状态学习得到。

  • Vaswani[6]等人的Scaled Dot-Product(^)缩放点积注意:

细心的童鞋可能早就发现了这东东和点积注意力很像,只是加了个scale factor。当输入较大时,softmax函数可能具有极小的梯度,难以有效学习,所以作者加入比例因子 1/n 。

Cheng[7]等人的Self-Attention(&)可以关联相同输入序列的不同位置。 从理论上讲,Self-Attention可以采用上面的任何 score functions。在一些文章中也称为“intra-attention” 

Hu[7]对此分了个类:

前面谈到的一些Basic Attention给人的感觉能够从序列中根据权重分布提取重要元素。而Multi-dimensional Attention能够捕获不同表示空间中的term之间的多个交互,这一点简单的实现可以通过直接将多个单维表示堆叠在一起构建。Wang[8]等人提出了coupled multi-layer attentions,该模型属于多层注意力网络模型。作者称,通过这种多层方式,该模型可以进一步利用术语之间的间接关系,以获得更精确的信息。

3.1 Hierarchical(层次) Attention

再来看看Hierarchical Attention,Yang[9]等人提出了Hierarchical Attention Networks,看下面的图可能会更直观:

Hierarchical Attention Networks

这种结构能够反映文档的层次结构。模型在单词和句子级别分别设计了两个不同级别的注意力机制,这样做能够在构建文档表示时区别地对待这些内容。Hierarchical attention可以相应地构建分层注意力,自下而上(即,词级到句子级)或自上而下(词级到字符级),以提取全局和本地的重要信息。自下而上的方法上面刚谈完。那么自上而下又是如何做的呢?让我们看看Ji[10]等人的模型:

Nested Attention Hybrid Model

和机器翻译类似,作者依旧采用encoder-decoder架构,然后用word-level attention对全局语法和流畅性纠错,设计character-level attention对本地拼写错误纠正。

3.2 Self-Attention

那Self-Attention又是指什么呢?

Self-Attention(自注意力),也称为”intra-attention”(内部注意力),是关联单个序列的不同位置的注意力机制,以便计算序列的交互表示。 它已被证明在很多领域十分有效比如机器阅读,文本摘要或图像描述生成。

  • 比如Cheng[11]等人在机器阅读里面利用了自注意力。当前单词为红色,蓝色阴影的大小表示激活程度,自注意力机制使得能够学习当前单词和句子前一部分词之间的相关性。

当前单词为红色,蓝色阴影的大小表示激活程度

  • 比如Xu[12]等人利用自注意力在图像描述生成任务。注意力权重的可视化清楚地表明了模型关注的图像的哪些区域以便输出某个单词。

我们假设序列元素为 V=vi ,其匹配向量为 u 。让我们再来回顾下前面说的基本注意力的对齐函数,attention score通过 a(u,vi) 计算得到,由于是通过将外部 u 与每个元素 vi 匹配来计算注意力,所以这种形式可以看作是外部注意力。当我们把外部u替换成序列本身(或部分本身),这种形式就可以看作为内部注意力(internal attention)。

我们根据文章[7]中的例子来看看这个过程,例如句子:”Volleyball match is in progress between ladies”。句子中其它单词都依赖着”match”,理想情况下,我们希望使用自我注意力来自动捕获这种内在依赖。换句话说,自注意力可以解释为,每个单词 vi 去和V序列中的内部模式 v′ ,匹配函数 a(v′,vi) 。 v′ 很自然的选择为V中其它单词 vj ,这样遍可以计算成对注意力得分。为了完全捕捉序列中单词之间的复杂相互作用,我们可以进一步扩展它以计算序列中每对单词之间的注意力。这种方式让每个单词和序列中其它单词交互了关系。

另一方面,自注意力还可以自适应方式学习复杂的上下文单词表示。譬如经典文章:”A structured self-attentive sentence embedding”。这篇文章提出了一种通过引入自注意力机制来提取可解释句子嵌入的新模型。 使用二维矩阵而不是向量来代表嵌入,矩阵的每一行都在句子的不同部分,想深入了解的可以去看看这篇文章,另外,文章的公式感觉真的很漂亮。

值得一提还有2017年谷歌提出的Transformer[6],这是一种新颖的基于注意力的机器翻译架构,也是一个混合神经网络,具有前馈层和自注意层。论文的题目挺霸气:Attention is All you Need,毫无疑问,它是2017年最具影响力和最有趣的论文之一。那这篇文章的Transformer的庐山真面目到底是这样的呢?

这篇文章为提出许多改进,在完全抛弃了RNN的情况下进行seq2seq建模。接下来一起来详细看看吧。


Key, Value and Query:

众所周知,在NLP任务中,通常的处理方法是先分词,然后每个词转化为对应的词向量。接着一般最常见的有二类操作,第一类是接RNN(变体LSTM、GRU、SRU等),但是这一类方法没有摆脱时序这个局限,也就是说无法并行,也导致了在大数据集上的速度效率问题。第二类是接CNN,CNN方便并行,而且容易捕捉到一些全局的结构信息。很长一段时间都是以上二种的抉择以及改造,知道谷歌提供了第三类思路:纯靠注意力,也就是现在要讲的这个东东。

将输入序列编码表示视为一组键值对(K,V)以及查询 Q,因为文章取K=V=Q,所以也自然称为Self Attention。

K, V像是key-value的关系从而是一一对应的,那么上式的意思就是通过Q中每个元素query,与K中各个元素求内积然后softmax的方式,来得到Q中元素与V中元素的相似度,然后加权求和,得到一个新的向量。其中因子 n 为了使得内积不至于太大。以上公式在文中也称为点积注意力(scaled dot-product attention):输出是值的加权和,其中分配给每个值的权重由查询的点积与所有键确定

而Transformer主要由多头自注意力(Multi-Head Self-Attention)单元组成。 在NMT的上下文中,键和值都是编码器隐藏状态。 在解码器中,先前的输出被压缩成查询Q,并且通过映射该查询以及该组键和值来产生下一个输出。

3.3 Memory-based Attention

Memory-based Attention又是什么呢?我们先换种方式来看前面的注意力,假设有一系列的键值对 (ki,vi) 存在内存中和查询向量q,这样便能重写为以下过程:

这种解释是把注意力作为使用查询q的寻址过程,这个过程基于注意力分数从memory中读取内容。聪明的童鞋肯定已经发现了,如果我们假设ki=vi ,这不就是前面谈到的基础注意力么?然而,由于结合了额外的函数,可以实现可重用性和增加灵活性,所以Memory-based attention mechanism可以设计得更加强大。

那为什么又要这样做呢?在nlp的一些任务上比如问答匹配任务,答案往往与问题间接相关,因此基本的注意力技术就显得很无力了。那处理这一任务该如何做才好呢?这个时候就体现了Memory-based attention mechanism的强大了,譬如Sukhbaatar[18]等人通过迭代内存更新(也称为多跳)来模拟时间推理过程,以逐步引导注意到答案的正确位置:

在每次迭代中,使用新内容更新查询,并且使用更新的查询来检索相关内容。一种简单的更新方法为相加 qt+1=qt+ct 。那么还有其它更新方法么?当然有,直觉敏感的童鞋肯定想到了,光是这一点,就可以根据特定任务去设计,比如Kuma[13]等人的工作。这种方式的灵活度也体现在key和value可以自由的被设计,比如我们可以自由地将先验知识结合到key和value嵌入中,以允许它们分别更好地捕获相关信息。看到这里是不是觉得文章灌水其实也不是什么难事了。

3.4 Soft/Hard Attention

这个概念由《Show, Attend and Tell: Neural Image Caption Generation with Visual Attention》提出,这是对attention另一种分类。SoftAttention本质上和Bahdanau等人[3]很相似,其权重取值在0到1之间,而Hard Attention取值为0或者1。

3.5 Global/Local Attention

Luong等人[4]提出了Global Attention和Local Attention。Global Attention本质上和Bahdanau等人[3]很相似。Global方法顾名思义就是会关注源句子序列的所有词,具体地说,在计算语义向量时,会考虑编码器所有的隐藏状态。而在Local Attention中,计算语义向量时只关注每个目标词的一部分编码器隐藏状态。由于Global方法必须计算源句子序列所有隐藏状态,当句子长度过长会使得计算代价昂贵并使得翻译变得不太实际,比如在翻译段落和文档的时候。

参考文献

[1] Attention? Attention!.

[2] Neural Machine Translation (seq2seq) Tutorial.

[3] Neural Machine Translation by Jointly Learning to Align and Translate, Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. ICLR, 2015.

[4] Effective approaches to attention-based neural machine translation, Minh-Thang Luong, Hieu Pham, and Christopher D Manning. EMNLP, 2015.

[5] Neural Turing Machines, Alex Graves, Greg Wayne and Ivo Danihelka. 2014.

[6] Attention Is All You Need, Ashish Vaswani, et al. NIPS, 2017.

[7] An Introductory Survey on Attention Mechanisms in NLP Problems Dichao Hu, 2018.

[8] Coupled Multi-Layer Attentions for Co-Extraction of Aspect and Opinion Terms Wenya Wang,Sinno Jialin Pan, Daniel Dahlmeier and Xiaokui Xiao. AAAI, 2017.

[9] Hierarchical attention networks for document classification Zichao Yang et al. ACL, 2016.

[10] A Nested Attention Neural Hybrid Model for Grammatical Error Correction Jianshu Ji et al. 2017.

[11] Long Short-Term Memory-Networks for Machine Reading Jianpeng Cheng, Li Dong and Mirella Lapata. EMNLP, 2016.

[12] Show, Attend and Tell: Neural Image Caption Generation with Visual Attention Kelvin Xu et al. JMLR, 2015.

[13] Ask me anything: Dynamic memory networks for natural language processing. Zhouhan Lin al. JMLR, 2016.

[14] A structured self-attentive sentence embedding Zhouhan Lin al. ICLR, 2017.

[15] Learning Sentence Representation with Guidance of Human Attention Shaonan Wang , Jiajun Zhang, Chengqing Zong. IJCAI, 2017.

[16] Sequence to Sequence Learning with Neural Networks Ilya Sutskever et al. 2014.

[17] Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation Kyunghyun Cho, Yoshua Bengio et al. EMNLP, 2014.

[18] End-To-End Memory Networks Sainbayar Sukhbaatar et al. NIPS, 2015.

[19] 《Attention is All You Need》浅读(简介+代码)

Swin Transformer 代码详解

code:https://github.com/microsoft/Swin-Transformer

代码详解: https://zhuanlan.zhihu.com/p/367111046

预处理:

对于分类模型,输入图像尺寸为 224×224×3 ,即 H=W=224 。按照原文描述,模型先将图像分割成每块大小为 4×4 的patch,那么就会有 56×56 个patch,这就是初始resolution,也是后面每个stage会降采样的维度。后面每个stage都会降采样时长宽降到一半,特征数加倍。按照原文及原图描述,划分的每个patch具有 4×4×3=48 维特征。

  • 实际在代码中,首先使用了PatchEmbed模块(这里的PatchEmbed包括上图中的Linear Embedding 和 patch partition层),定义如下:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): # embed_dim就是上图中的C超参数
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

可以看到,实际操作使用了一个卷积层conv2d(3, 96, 4, 4),直接就做了划分patch和编码初始特征的工作,对于输入 x:B×3×224×224 ,经过一层conv2d和LayerNorm得到 x:B×562×96 。然后作为对比,可以选择性地加上每个patch的绝对位置编码,原文实验表示这种做法不好,因此不会采用(ape=false)。最后经过一层dropout,至此,预处理完成。另外,要注意的是,代码和上面流程图并不符,其实在stage 1之前,即预处理完成后,维度已经是 H/4×W/4×C ,stage 1之后已经是 H/8×W/8×2C ,不过在stage 4后不再降采样,得到的还是 H/32×W/32×8C 。

stage处理

我们先梳理整个stage的大体过程,把简单的部分先说了,再深入到复杂得的细节。每个stage,即代码中的BasicLayer,由若干个block组成,而block的数目由depth列表中的元素决定。每个block就是W-MSA(window-multihead self attention)或者SW-MSA(shift window multihead self attention),一般有偶数个block,两种SA交替出现,比如6个block,0,2,4是W-MSA,1,3,5是SW-MSA。在经历完一个stage后,会进行下采样,定义的下采样比较有意思。比如还是 56×56 个patch,四个为一组,分别取每组中的左上,右上、左下、右下堆叠一起,经过一个layernorm,linear层,实现维度下采样、特征加倍的效果。实际上它可以看成一种加权池化的过程。代码如下:

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

在经历完4个stage后,得到的是 (H/32×W/32)×8C 的特征,将其转到 8C×(H/32×W/32) 后,接一个AdaptiveAvgPool1d(1),全局平均池化,得到 8C 特征,最后接一个分类器。

PatchMerging

Block处理

SwinTransformerBlock的结构,由LayerNorm层、windowAttention层(Window MultiHead self -attention, W-MSA)、MLP层以及shiftWindowAttention层(SW-MSA)组成。

上面说到有两种block,block的代码如下:

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # 左图中最下边的LN层layerNorm层
        self.norm1 = norm_layer(dim)
        # W_MSA层或者SW-MSA层,详细的介绍看WindowAttention部分的代码
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # 左图中间部分的LN层
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # 左图最上边的MLP层
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # 这里利用shift_size控制是否执行shift window操作
        # 当shift_size为0时,不执行shift操作,对应W-MSA,也就是在每个stage中,W-MSA与SW-MSA交替出现
        # 例如第一个stage中存在两个block,那么第一个shift_size=0就是W-MSA,第二个shift_size不为0
        # 就是SW-MSA
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
#slice() 函数实现切片对象,主要用在切片操作函数里的参数传递。class slice(start, stop[, step])
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
## 上述操作是为了给每个窗口给上索引

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        # 如果需要计算 SW-MSA就需要进行循环移位。
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
#shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

W-MSA

W-MSA比较简单,只要其中shift_size设置为0就是W-MSA。下面跟着代码走一遍过程。

  • 输入: x:B×562×96 , H,W=56
  • 经过一层layerNorm
  • 变形: x:B×56×56×96
  • 直接赋值给shifted_x
  • 调用window_partition函数,输入shifted_xwindow_size=7
  • 注意窗口大小以patch为单位,比如7就是7个patch,如果56的分辨率就会有8个窗口。
  • 这个函数对shifted_x做一系列变形,最终变成 82B×7×7×96
  • 返回赋值给x_windows,再变形成 82B×72×96 ,这表示所有图片,每个图片的64个window,每个window内有49个patch。
  • 调用WindowAttention层,这里以它的num_head为3为例。输入参数为x_windowsself.attn_mask,对于W-MSA,attn_mask为None,可以不用管。

WindowAttention代码如下:

代码中使用7×7的windowsize,将feature map分割为不同的window,在每个window中计算自注意力。

Self-attention的计算公式(B为相对位置编码)

绝对位置编码是在进行self-attention计算之前为每一个token添加一个可学习的参数,相对位置编码如上式所示,是在进行self-attention计算时,在计算过程中添加一个可学习的相对位置参数。

假设window_size = 2*2即每个窗口有4个token (M=2) ,如图1所示,在计算self-attention时,每个token都要与所有的token计算QK值,如图6所示,当位置1的token计算self-attention时,要计算位置1与位置(1,2,3,4)的QK值,即以位置1的token为中心点,中心点位置坐标(0,0),其他位置计算与当前位置坐标的偏移量。

坐标变换
坐标变换
相对位置索引求解流程图

最后生成的是相对位置索引,relative_position_index.shape = (M2,M2) ,在网络中注册成为一个不可学习的变量,relative_position_index的作用就是根据最终的索引值找到对应的可学习的相对位置编码。relative_position_index的数值范围(0~8),即 (2M−1)∗(2M−1) ,所以相对位置编码(relative position bias table)可以由一个3*3的矩阵表示,如图7所示:这样就根据index对应位置的索引找到table对应位置的值作为相对位置编码。

图7 相对位置编码

图7中的0-8为索引值,每个索引值都对应了 M2 维可学习数据(每个token都要计算 M2 个QK值,每个QK值都要加上对应的相对位置编码)

继续以图6中 M=2 的窗口为例,当计算位置1对应的 M2 个QK值时,应用的relative_position_index = [ 4, 5, 7, 8] (M2)个 ,对应的数据就是图7中位置索引4,5,7,8位置对应的 M2 维数据,即relative_position.shape = (M2∗M2)

相对位置编码在源码WindowAttention中应用,了解原理之后就很容易能够读懂程序:

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim # 输入通道的数量
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH  初始化表

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0]) # coords_h = tensor([0,1,2,...,self.window_size[0]-1])  维度=Wh
        coords_w = torch.arange(self.window_size[1]) # coords_w = tensor([0,1,2,...,self.window_size[1]-1])  维度=Ww

        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww


        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1

        '''
        后面我们需要将其展开成一维偏移量。而对于(2,1)和(1,2)这两个坐标,在二维上是不同的,但是通过将x\y坐标相加转换为一维偏移的时候
        他们的偏移量是相等的,所以需要对其做乘法操作,进行区分
        '''

        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # 计算得到相对位置索引
        # relative_position_index.shape = (M2, M2) 意思是一共有这么多个位置
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww 

        '''
        relative_position_index注册为一个不参与网络学习的变量
        '''
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        '''
        使用从截断正态分布中提取的值填充输入张量
        self.relative_position_bias_table 是全0张量,通过trunc_normal_ 进行数值填充
        '''
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            N: number of all patches in the window
            C: 输入通过线性层转化得到的维度C
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        '''
        x.shape = (num_windows*B, N, C)
        self.qkv(x).shape = (num_windows*B, N, 3C)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).shape = (num_windows*B, N, 3, num_heads, C//num_heads)
        self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).shape = (3, num_windows*B, num_heads, N, C//num_heads)
        '''
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        '''
        q.shape = k.shape = v.shape = (num_windows*B, num_heads, N, C//num_heads)
        N = M2 代表patches的数量
        C//num_heads代表Q,K,V的维数
        '''
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # q乘上一个放缩系数,对应公式中的sqrt(d)
        q = q * self.scale

        # attn.shape = (num_windows*B, num_heads, N, N)  N = M2 代表patches的数量
        attn = (q @ k.transpose(-2, -1))

        '''
        self.relative_position_bias_table.shape = (2*Wh-1 * 2*Ww-1, nH)
        self.relative_position_index.shape = (Wh*Ww, Wh*Ww)
        self.relative_position_index矩阵中的所有值都是从self.relative_position_bias_table中取的
        self.relative_position_index是计算出来不可学习的量
        '''
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        '''
        attn.shape = (num_windows*B, num_heads, M2, M2)  N = M2 代表patches的数量
        .unsqueeze(0):扩张维度,在0对应的位置插入维度1
        relative_position_bias.unsqueeze(0).shape = (1, num_heads, M2, M2)
        num_windows*B 通过广播机制传播,relative_position_bias.unsqueeze(0).shape = (1, nH, M2, M2) 的维度1会broadcast到数量num_windows*B
        表示所有batch通用一个索引矩阵和相对位置矩阵
        '''
        attn = attn + relative_position_bias.unsqueeze(0)

        # mask.shape = (num_windows, M2, M2)
        # attn.shape = (num_windows*B, num_heads, M2, M2)
        if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        '''
        v.shape = (num_windows*B, num_heads, M2, C//num_heads)  N=M2 代表patches的数量, C//num_heads代表输入的维度
        attn.shape = (num_windows*B, num_heads, M2, M2)
        attn@v .shape = (num_windows*B, num_heads, M2, C//num_heads)
        '''
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)   # B_:num_windows*B  N:M2  C=num_heads*C//num_heads

        #   self.proj = nn.Linear(dim, dim)  dim = C
        #   self.proj_drop = nn.Dropout(proj_drop)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # x.shape = (num_windows*B, N, C)  N:窗口中所有patches的数量

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

在上述程序中有一段mask相关程序:

if mask is not None:
            nW = mask.shape[0]
            # attn.view(B_ // nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) 第一个M2代表有M2个token,第二个M2代表每个token要计算M2次QKT的值
            # broadcast相加
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # attn.shape = (B, num_windows, num_heads, M2, M2)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

这个部分对应的是Swin Transformer Block 中的SW-MSA

  • 输入 x:82B×72×96 。
  • 产生 QKV ,调用线性层后,得到 82B×72×(96×3) ,拆分给不同的head,得到 82B×72×3×3×32 ,第一个3是 QKV 的3,第二个3是3个head。再permute成 3×82B×3×72×32 ,再拆解成 q,k,v ,每个都是 82B×3×72×32 。表示所有图片的每个图片64个window,每个window对应到3个不同的head,都有一套49个patch、32维的特征。
  • q 归一化
  • qk 矩阵相乘求特征内积,得到 attn:82B×3×72×72
  • 得到相对位置的编码信息relative_position_bias
    • 代码如下:
self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
  • 这里以window_size=3为例,解释以下过程:首先生成 coords:2×3×3 ,就是在一个 3×3 的窗口内,每个位置的 y,x 坐标,而relative_coords为 2×9×9 ,就是9个点中,每个点的 y 或 x 与其他所有点的差值,比如 [0][3][1] 表示3号点(第二行第一个点)与1号点(第一行第二个点)的 y 坐标的差值。然后变形,并让两个坐标分别加上 3−1=2 ,是因为这些坐标值范围 [0,2] ,因此差值的最小值为-2,加上2后从0开始。最后让 y 坐标乘上 2×3−1=5 ,应该是一个trick,调整差值范围。最后将两个维度的差值相加,得到relative_position_index, 32×32 ,为9个点之间两两之间的相对位置编码值,最后用来到self.relative_position_bias_table中寻址,注意相对位置的最大值为 (2M−2)(2M−1) ,而这个table最多有 (2M−1)(2M−1) 行,因此保证可以寻址,得到了一组给多个head使用的相对位置编码信息,这个table是可训练的参数。
  • 回到代码中,得到的relative_position_bias为 3×72×72
  • 将其加到attn上,最后一个维度softmax,dropout
  • 与 v 矩阵相乘,并转置,合并多个头的信息,得到 82B×72×96
  • 经过一层线性层,dropout,返回
  • 返回赋值给attn_windows,变形为 82B×7×7×96
  • 调用window_reverse,打回原状: B×56×56×96
  • 返回给 x ,经过FFN:先加上原来的输入 x 作为residue结构,注意这里用到timmDropPath,并且drop的概率是整个网络结构线性增长的。然后再加上两层mlp的结果。
  • 返回结果 x 。

这样,整个过程就完成了,剩下的就是SW-MSA的一些不同的操作。

  1. 首先将windows进行半个窗口的循环移位,上图中的1, 2步骤,使用torch.roll实现。
  2. 在相同的窗口中计算自注意力,计算结果如下右图所示,window0的结构保存,但是针对window2的计算,其中3与3、6与6的计算生成了attn mask 中window2中的黄色区域,针对windows2中3与6、6与3之间不应该计算自注意力(attn mask中window2的蓝色区域),将蓝色区域mask赋值为-100,经过softmax之后,起作用可以忽略不计。同理window1与window3的计算一致。
  3. 最后再进行循环移位,恢复原来的位置。

原论文图中的Stage和程序中的一个Stage不同:

程序中的BasicLayer为一个Stage,在BasicLayer中调用了上面讲到的SwinTransformerBlock和PatchMerging模块:

class BasicLayer(nn.Module):  # 论文图中每个stage里对应的若干个SwinTransformerBlock
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth # swin_transformer blocks的个数
        self.use_checkpoint = use_checkpoint

        # build blocks  从0开始的偶数位置的SwinTransformerBlock计算的是W-MSA,奇数位置的Block计算的是SW-MSA,且shift_size = window_size//2
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)  # blk = SwinTransformerBlock
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

Part 3 : 不同视觉任务输出

程序中对应的是图片分类任务,经过Part 2 之后的数据通过 norm/avgpool/flatten:

 x = self.norm(x)  # B L C
 x = self.avgpool(x.transpose(1, 2))  # B C 1
 x = torch.flatten(x, 1) # B C

之后通过nn.Linear将特征转化为对应的类别:

self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

应用于其他不同的视觉任务时,只需要将输出进行特定的修改即可。

完整的SwinTransformer程序如下:

class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes # 1000
        self.num_layers = len(depths) # [2, 2, 6, 2]  Swin_T 的配置
        self.embed_dim = embed_dim # 96
        self.ape = ape # False
        self.patch_norm = patch_norm # True
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))  # 96*2^3
        self.mlp_ratio = mlp_ratio # 4

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features) # norm_layer = nn.LayerNorm
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)  # 使用self.apply 初始化参数

    def _init_weights(self, m):
        # is_instance 判断对象是否为已知类型
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)  # x.shape = (H//4, W//4, C)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)  # self.pos_drop = nn.Dropout(p=drop_rate)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1) # B C
        return x

    def forward(self, x):
        x = self.forward_features(x)  # x是论文图中Figure 3 a图中最后的输出
        #  self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        x = self.head(x) # x.shape = (B, num_classes)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops

补充:有关swin transformer相对位置编码:

VIT

Dosovitskiy et al. An image is worth 16×16 words: transformers for image recognition at scale. In ICLR, 2021

step1 :分割图片

step2 向量化:从九个快变成九个向量

step3:向量线性变换:(linear embedding线性嵌入层)

step4:将位置编码添加到z上:

step4:添加一个cls向量:

step5:只利用cls的输出

按照上面的流程图,一个ViT block可以分为以下几个步骤

(1) patch embedding:例如输入图片大小为224×224,将图片分为固定大小的patch,patch大小为16×16,则每张图像会生成224×224/16×16=196个patch,即输入序列长度为196,每个patch维度16x16x3=768,线性投射层的维度为768xN (N=768),因此输入通过线性投射层之后的维度依然为196×768,即一共有196个token,每个token的维度是768。这里还需要加上一个特殊字符cls,因此最终的维度是197×768。到目前为止,已经通过patch embedding将一个视觉问题转化为了一个seq2seq问题

(2) positional encoding(standard learnable 1D position embeddings):ViT同样需要加入位置编码,位置编码可以理解为一张表,表一共有N行,N的大小和输入序列长度相同,每一行代表一个向量,向量的维度和输入序列embedding的维度相同(768)。注意位置编码的操作是sum,而不是concat。加入位置编码信息之后,维度依然是197×768

(3) LN/multi-head attention/LN:LN输出维度依然是197×768。多头自注意力时,先将输入映射到q,k,v,如果只有一个头,qkv的维度都是197×768,如果有12个头(768/12=64),则qkv的维度是197×64,一共有12组qkv,最后再将12组qkv的输出拼接起来,输出维度是197×768,然后在过一层LN,维度依然是197×768

(4) MLP:将维度放大再缩小回去,197×768放大为197×3072,再缩小变为197×768

一个block之后维度依然和输入相同,都是197×768,因此可以堆叠多个block。最后会将特殊字符cls对应的输出 zL0 作为encoder的最终输出 ,代表最终的image presentation(另一种做法是不加cls字符,对所有的tokens的输出做一个平均),如下图公式(4),后面接一个MLP进行图片分类

vit需要预训练+微调

• Pretrain the model on Dataset A, fine-tune the model on Dataset B,
and evaluate the model on Dataset B.
• Pretrained on ImageNet (small), ViT is slightly worse than ResNet.
• Pretrained on ImageNet-21K (medium), ViT is comparable to ResNet.
• Pretrained on JFT (large), ViT is slightly better than ResNet.

效果: