CycleMLP:用于密集预测的类似 MLP 的架构

paper:https://arxiv.org/abs/2107.10224

作者单位:香港大学, 商汤科技
代码:https://github.com/ShoufaChen/CycleMLP

核心:用 Cycle-FC来替换Spatial FC(计算量大且网络对于不同图像分辨率的输入不可接受,且不能用于下游任务)

本文提出了一个简单的 MLP-like 的架构 CycleMLP,它是视觉识别和密集预测的通用主干,不同于现代 MLP 架构,例如 MLP-Mixer、ResMLP 和 gMLP,其架构与图像大小相关,因此是在目标检测和分割中不可行。

与现代方法相比,CycleMLP 有两个优势。

(1) 可以应对各种图像尺寸。

(2) 利用局部窗口实现对图像大小的线性计算复杂度。

相比之下,以前的 MLP 具有二次计算,因为它们具有完全的空间连接。

单个 CycleMLP Block 依然是分为 Token-mixing MLP 和 Channel mixing MLP,其中作者主要的贡献点在于替换 MLP-mixer 的 Token-mixing MLP 为 Cycle-FC。所以整个 CycleMLP Block 可以描述为:

何为 Cycle-FC ?要回答这个问题,我们首先来回顾一下 Channel FC 以及 Spatial FC.

Channel FC 即通道方向的映射,等效与1×1 卷积,其参数量与图像尺寸无关,而与通道数(token 维度)有关。假设输入输出特征图尺寸一致,则参数量为 C^2,其中 C 为通道数。而计算量则为 HWC^2,其中 H W 分别为特征图的高和宽。如果只考虑计算量与图像尺寸的影响的话,则为 O ( H W ) 。

Spatial FC 即 MLP-Mixer 使用的 Token-mixing 全连接层,在这里我们都是只考虑一个全连接层,则其实现的是 H W − > H W 的映射,参数量为 H^2W^2,计算量也为 H 2 W 2 C H^2W^2C,如果只考虑计算量与图像尺寸的影响的话,则为 O(H^2W^2)。并且HW 大小固定,网络对于不同图像分辨率的输入不可接受,且不能用于下游任务以及使用类似 EfficientNetV2 等的多分辨率训练策略。

为什么我们可以在复杂度分析时只考虑 H W 的影响呢?因为在金字塔结构的 MLP 中,通常一开始的 patch size 为 4,然后输入尺寸为 224×224,则一开始的 H = W = 56 = 224 / 4 ,而 C = 64 或者 96 ,所以C≪HW。如果对于下游任务而言,例如输入变为了512×512,则它们之间的差距更大了。为此在这里我们可以在复杂度分析中暂时只考虑 H W  而忽略 C 。

为了同时克服 Spatial 对于图像输入尺寸敏感以及计算量大的问题,作者提出了 Cycle-FC。其只是用通道方向的映射并且计算量和 Channel FC 保持一致。其说白了就是不断地以 [+1 0 -1 0 +1 0 -1 0 +1 …] 的方式移动特征图,将不同空间位置的特征对齐到同一个通道上,然后使用1×1 卷积。

回忆 AS-MLP,其采用的特征图移动方式则为 [+1 0 -1 +1 0 -1 +1 0 -1] 这样的成组方式,CycleMLP 则是使用“楼梯型”方式,但是其思想没有本质不同。此外,AS-MLP 确实对特征图进行了 Shift,并且采用了 zero-padding,而 CycleMLP 在具体实现过程中则是使用可变形卷积加以实现的。我个人对于 AS-MLP 与 CycleMLP 的理解如下图所示,可见他们其实核心思想是一致的。

from torchvision.ops.deform_conv import deform_conv2d

img3

CycleMLP 与 AS-MLP 只并行 H 和 W 方向的移动不同,CycleMLP 其实是三条支路并行:H 方向,W 方向,以及不移动特征图做通道方向映射。此外,AS-MLP 在一开始还做了一次 Channel Projection 进行降维。

img5

CycleMLP 最终使用的和 ViP 一样,使用 Split Attention 来融合三条支路

class CycleMLP(nn.Module):
    def __init__(self, dim, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)

        self.sfc_h = CycleFC(dim, dim, (1, 3), 1, 0)
        self.sfc_w = CycleFC(dim, dim, (3, 1), 1, 0)

        self.reweight = Mlp(dim, dim // 4, dim * 3)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        h = self.sfc_h(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        w = self.sfc_w(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        c = self.mlp_c(x)

        a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)

        x = h * a[0] + w * a[1] + c * a[2]

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

最后提一句,作者将投影区间定义为是 Pseudo-Kernel,这其实也是我们常说的 感受野 一词。

img4

2.2 整体网络结构

CycleMLP 的 Patch Embedding 也很有特色,使用卷积核大小为 7×7 ,步长为 4 的卷积。后续 Hire-MLP 其实也是这样进行的 Patch Embedding。相比而言 Swin 使用卷积核大小为 4×4,步长为 4 的卷积。在近期的我自己的小实验中也发现:Patch Embedding 时具有重叠会更好,这样可以避免边界效应并在小数据集上提升性能。CycleMLP 中间采用多阶段金字塔模型,总共分为 4 个阶段,每个阶段交替重复使用 CycleMLP Block。下采样使用卷积核大小为3×3,步长为 2 的卷积,这样做也有重叠,Hire-MLP 也是这样子哈。最后经过全局池化后连接一个全连接分类器即可。作者一共提出来了四种配置:

请添加图片描述

在这四种配置,Si​ 指 Patch Embedding 中的 Patch size,Ci​ 指 Patch Embedding 的输出编码特征维度,E i ​ 为 Channel-mixing MLP 中两个全连接层中第一个全连接层的 expand radio,Li​ 则是不同 Stage 中 Block 的重复次数。

3. 下游任务实验

CycleMLP 旨在为 MLP 模型的目标检测、实例分割和语义分割提供一个有竞争力的基线。与 AS-MLP 不同之处在于,CycleMLP 在 ADE20K 上进行实验,而 AS-MLP 在 COCO 上进行的实验。这真的是巧合,还是故意避开?不敢问也不敢说。

目标检测性能表现:相比 PVT,CycleMLP 都更具有优势。

请添加图片描述

语义分割性能表现:特别是,CycleMLP 在 ADE20K val 上达到了 45.1 mIoU,与 Swin (45.2 mIOU) 相当。

请添加图片描述

4. 消融实验

作者一共进行了三组消融实验:

  • Cycle-FC VS Spacial-FC and Channel-FC: 作者将 CycleMLP 中的 Cycle-FC 替换为 Spacial-FC 或者 Channel-FC,结果发现 CycleMLP 具有更好的性能。但是只有 Channel-FC,也能达到 79.4% 的性能,真的这么高吗,比 ResNet 高那么多…
请添加图片描述
  • Cycle-FC 中三条支路的选择:Cycle-FC 中作者并行了三条支路,对他们的消融实验发现,同时拥有正交 H 和 W 方向效果很好,加上不动之后效果更好。两倍 H 方向或者两倍 W 方向比仅含有 H 或者 W 方向会好一些。
请添加图片描述
  • 测试分辨率的影响:最终发现测试正确率随分辨率先升后降,CycleMLP 表现最好。
请添加图片描述

4. 总结与反思

CycleMLP 提出了 Cycle-FC,即将不同 token 的特征对齐到同一个通道,然后使用通道映射,从而实现网络参数量计算量的降低,以及对图像分辨率不敏感。CycleMLP 也在下游任务上测试了自己的性能表现。整体而言做得还是很充分的。不过其试图造一些新的名词以强化贡献,例如 Cycle-FC 其实就是移动特征图,Pseudo-Kernel 其实就是卷积核感受野的概念。最终 CycleMLP 通过三条并行的支路构建了十字形感受野。相比 AS-MLP,CycleMLP 在感受野分析上略显不足,没有更泛化地分析以及进行消融实验。比如 CycleMLP 也可以间隔采样,例如 [+4 +2 0 -2 -4 -2 0 2 4 2 0 -2 …],就可以构建 AS-MLP 那种空洞的更大范围的感受野。(最后插一句:CycleMLP 和 AS-MLP,就像 ResMLP 与 MLP-Mixer,学术界的 Idea 真的能够这么惊人的一致吗?)

AS-MLP:首个检测与分割领域MLP架构

paper: https://arxiv.org/abs/2107.08391

github:https://github.com/svip-lab/AS-MLP

本文是上海科技大学在MLP架构方面的探索,它设计了一种轴向移位操作以便于进行空间信息交互。在架构方面,AS-MLP采用了类似PVT的分层架构,因为可以轻易的迁移到下游任务。所提方法在ImageNet数据集上取得了优于其他MLP架构的性能,在COC检测与ADE20K分割任务上取得了与Swin相当的性能。值得一提的是,AS-MLP是首个迁移到下游任务的MLP架构。注:CycleMLP与AS-MLP属于同一时期的工作,发到arxiv的时间也只差两天,说两者都是首个其实也可以。

本文提出了一种轴向移动架构AS-MLP(Axial Shifted MLP)用于不同的视觉任务(包含图像分类、检测以及分割)。不同于MLP-Mixer通过矩阵转置+词混叠MLP进行全局空域特征编码,我们在局部特征通信方向投入了更多的关注。

通过轴向移动特征信息,AS-MLP可以得到不同方向的信息流,这有助于捕获局部相关性。该操作使得我们采用纯MLP架构即可取得与CNN相同的感受野。我们还可以类似卷积核设置AS-MLP模块的感受野尺寸以及扩张因子。如此简单而有效的架构取得了优于其他MLP架构的性能,同时具有与Transformer架构(比如Swin Transformer)相当的性能,甚至具有稍少的FLOPs。比如,AS-MLP在ImageNet数据集上凭借88M参数量+15.2GFLOPs取得了83.3%top1精度,且无需额外训练数据。

此外,所提AS-MLP也是首个用于下游任务(如目标检测、语义分割)的MLP架构。AS-MLP在COC验证集上取得了51.5mAP指标,在ADE20K数据集上取得了49.5mIoU指标,具有与Transformer架构相当的性能。

Method:

Comparisons between AS-MLP, Convolution, Transformer and MLP-Mixer

在这里,我们将AS-MLP、卷积、Swin以及MLP-Mixer进行对比分析。尽管这些模型是从不同角度出发设计得到,但它们均基于给定输出位置点,其值依赖于局部特征的加权。这些采样位置包含局部依赖与长距离依赖。

从上述对比图可以看到:

  • 卷积是一种局部感受野的操作,更适合于提取具有局部依赖关系的特征;
  • Swin同样是一种局部感受野操作,Swin为自注意力机制引入了局部性提升了Transformer架构的性能,同时也降低了计算复杂度;
  • MLP-Mixer是一种全局感受野操作,它仅仅由矩阵转置与MLP操作构成;
  • AS-MLP是一种局部“十”字感受野操作,它可以更好的提取局部依赖关系。

Variants of AS-MLP Architecture

前面的Figure仅仅给出了Tiny版本的AS-MLP架构,参考DeiT与Swin,我们通过调整模块数与通道数构建了不同大小的模型。

image.png

Experiments

ImageNet Classification

image.png

上表给出了所提方法在ImageNet数据上的性能对比,从中可以看到:

  • 所提AS-MLP取得了比其他MLP架构更优的性能,同时具有相似的参数量与FLOPs;
  • AS-MLP-S取得了83.1%的top1精度同时具有比Mixer-B/16、ViP-Medium/7更少的参数量;
  • 此外,AS-MLP-B取得了与Swin相当的性能:83.3%。
image.png

此外,我们还对比了端侧配置版本的AS-MLP,结果见上表。可以看到:在端侧配置下,所提方法大幅超越了Swin Transformer。

COCO Detection

image.png

上表对比了COCO检测任务上的性能对比,可以看到:

  • 所提AS-MLP是首个用于下游任务的MLP架构;
  • 所提AS-MLP取得了与Swin相当的性能。具体来说,在Cascade Mask R-CNN+Swin-B取得了51.9AP指标,参数量为145M;而AS-MLP-B取得了51。5AP指标,参数量为145M。

ADE20K Segmentation

image.png

上表给出了ADE20K分割任务上的性能对比,从中可以看到:

  • 所提AS-MLP同样是首个用于分割任务的MLP架构;
  • AS-MLP-T取得了比Swin-T等有的性能,同时具有稍少FLOPs;
  • UperNet+Swin-B取得了49.7mIoU,参数量为121M,计算量为1188GFLOPs;而UperNet+AS-MLP-B取得了49.5mIoU,参数量121M,计算量为1166GFLOPs。

Ablation Study

AS-MLP的核心是轴向移动,接下来我们将对其不同成分进行消融分析,所有试验均基于AS-MLP-T实现。

image.png

上表对比了不同padding方式、不同移动尺寸以及不同扩展比例的性能对比,从中可以看到:

  • zero-padding更适合于AS-MLP设计;
  • 提升扩张因子会轻微降低模型性能;
  • 提升移动尺寸,模型精度会先上升后下降。
  • 基于上述分析,我们采用shift=5,zero-padding,dilation=1。

image.png
我们同时还比较了AS-MLP模块的不同链接类型,结果见上表,从中可以看到:在不同移动尺寸下,并行连接总是具有比串行连接更佳性能

Comparsion with S2MLP

在初看到该文时,第一感觉这个与百度的那篇S2MLP(见下图核心模块)真的非常相似,都是采用了垂直、水平移位方式进行空间信息交互,而且还都是上下左右四个方向。可惜AS-MLP并未与S2MLP进行对比,反而比较晚(指的是见刊arxiv)的ViP进行的对比。

image.png

既然提到了,我们还是对S2MLP与ASMLP进行一下对比吧。

  • 在整体架构方面,AS-MLP采用了类似PVT的分层架构,而S2MLP一文则是采用了类似ViT的柱状架构;
  • 在应用方面,AS-MLP即可应用于图像分类,还可以迁移到下游任务中;而S2MLP则仅适用于图像分类,并不适用下游任务;
  • 在核心模型方面,AS-MLP采用并行垂直、水平移动,分别进行特征汇聚后再进行特征相加汇聚;而S2MLP则采用分组方式,不同组进行不同方向的移动,然后再进行空间信息汇聚;
  • 在模型性能方面,AS-MLP取得了与Swin相当的性能,比ViP更优的性能;而S2MLP的性能则弱于Swin与ViP;
  • 最后一点,AS-MLP开源了,但S2MLP并未开源。

Vision MLP —Swin-MLP

code:https://github.com/microsoft/Swin-Transformer

Swin MLP 代码来自 Swin Transformer 的官方实现。Swin Transformer 作者们在已有模型的基础上实现了 Swin MLP 模型,证明了 Window-based attention 对于 MLP 模型的有效性。

把张量 (B, H, W, C) 分成 window (B×H/M×W/M, M, M, C),其中M是 window_size。这一步相当于得到 B×H/M×W/M 个大小为 (M, M, C) 的 window。

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

把 window (B×H/M×W/M, M, M, C) 变回张量 (B, H, W, C)。

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

一个 Swin MLP Block

class SwinMLPBlock(nn.Module):
    r""" Swin MLP Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.padding = [self.window_size - self.shift_size, self.shift_size,
                        self.window_size - self.shift_size, self.shift_size]  # P_l,P_r,P_t,P_b

        self.norm1 = norm_layer(dim)
        # use group convolution to implement multi-head MLP
        self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2,
                                     self.num_heads * self.window_size ** 2,
                                     kernel_size=1,
                                     groups=self.num_heads)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # shift
        if self.shift_size > 0:
            P_l, P_r, P_t, P_b = self.padding
            shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0)
        else:
            shifted_x = x
        _, _H, _W, _ = shifted_x.shape

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # Window/Shifted-Window Spatial MLP
        x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads)
        x_windows_heads = x_windows_heads.transpose(1, 2)  # nW*B, nH, window_size*window_size, C//nH
        x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size,
                                                  C // self.num_heads)
        spatial_mlp_windows = self.spatial_mlp(x_windows_heads)  # nW*B, nH*window_size*window_size, C//nH
        spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size,
                                                       C // self.num_heads).transpose(1, 2)
        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C)

        # merge windows
        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W)  # B H' W' C

        # reverse shift
        if self.shift_size > 0:
            P_l, P_r, P_t, P_b = self.padding
            x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

注意 F.pad(x, [0, 0, P_l, P_r, P_t, P_b], “constant”, 0) 的对象是 x,维度是 (B, H, W, C)。
padding相当于是第3维 (C 这一维) 不填充,第2维 (W 这一维) 左右分别填充 P_l, P_r,第1维 (H 这一维) 左右分别填充 P_t, P_b。
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C:
这句代码把 shifted_x 分成 nW*B 个 windows,其中每个 window 的维度是 (window_size, window_size, C)。

# reverse shift
if self.shift_size > 0:
P_l, P_r, P_t, P_b = self.padding
x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()
else:
x = shifted_x
这里是如果进行了 shift 操作,则最后取得结果也应该是没有 padding 的部分,正好是 shifted_x[:, P_t:-P_b, P_l:-P_r, :]。

一个 Swin MLP Block 的 FLOPs,注意 WSA 的计算量是:

FLOPs (WSA) = (window_size * window_size)^2 * dim * number_window

def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W

        # Window/Shifted-Window Spatial MLP
        if self.shift_size > 0:
            nW = (H / self.window_size + 1) * (W / self.window_size + 1)
        else:
            nW = H * W / self.window_size / self.window_size
        flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

每个 stage 之间的 PatchMerging连接,把 resolution 变为一半,dim 变为2倍。

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def flops(self):
        H, W = self.input_resolution
        # norm
        flops = H * W * self.dim
        # reduction
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
  • Patch Merging 操作把相邻的 2×2 个 tokens 给合并到一起,得到的 token 的维度是4C。
    Patch Merging 操作再通过一次线性变换把维度降为2C。

一个 Swin MLP Layer

class BasicLayer(nn.Module):
    """ A basic Swin MLP layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., drop=0., drop_path=0.,
                 norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinMLPBlock(dim=dim, input_resolution=input_resolution,
                         num_heads=num_heads, window_size=window_size,
                         shift_size=0 if (i % 2 == 0) else window_size // 2,
                         mlp_ratio=mlp_ratio,
                         drop=drop,
                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                         norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
  • 包含 depth 个 Swin MLP Block。
    注意计算 FLOPs 的方式:每个 blk 和 downsample 都自带 flops() 方法,可以直接来调用。

PatchEmbedded 操作

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
  • 和 ViT 的 Patch Embedded 操作一样,本质上是一个 K=patch size,s=patch size 的 nn.Conv2d 操作,注意卷积 FLOPs 的计算公式即可。

SwinMLP 整体模型架构

class SwinMLP(nn.Module):
    r""" Swin MLP

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin MLP layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        drop_rate (float): Dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               drop=drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        # adaptive average pool
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        # head
        flops += self.num_features * self.num_classes
        return flops
  • 由4个 Stage 组成,每个 Stage 由 BasicLayer 实现。
    传入的 depths 代表每个 Stage 的层数,比如 Swin-T 就是:[2, 2, 6, 2]。

Vision MLP –Pay Attention to MLPs

MLP-Mixer的增强版,带gating的MLP。有两个版本,分别是gMLP和aMLP。Pay-Attention-to-MLPs是gMLP版本,同时也提出了gMLP的增强版aMLP。

paper: https://arxiv.org/abs/2105.08050

github: https://github.com/antonyvigouret/Pay-Attention-to-MLPs

此文和最近刊出MLP文章相同,旨在探究self-attention对于Transformer来说是否至关重要。并在CV和NLP上的相关任务进行实验。

Transformer结构具有可并行化汇聚所有token间的空间信息的优点。众所周知self-attention是通过计算输入间的空间关系动态的引入归纳偏置,同时被静态参数化的MLP能表达任意的函数,所以self-attention对于Transformer在CV和NLP等领域的成功是否是至关重要的呢?

  • 此文提出了一个基于MLP的没有self-attention结构名为gMLP,仅仅存在静态参数化的通道映射(channel projections)和空间映射(spatial projections)。同时作者通过实验发现当对空间映射的线性结果进行门机制乘法得到的效果最好
  • 此文使用gMLP做图片分类并在ImageNet上取得了与DeiT、ViT等Transformer模型相当的效果。与先前的MLP模型MLP-Mixer相比,gMLP做到了参数更少(参数减少66%)效果更强(效果提升3%)。
  • 此文使用gMLP做masked language modeling,gMLP采用和Bert一样的设置最小化perplexity取得了和Transformer模型预训练一样好的效果。通过pretraining和finetuning实验发现随着模型容量的增加,gMLP比Transformer提升更大,表明模型相较于self-attention可能对于模型容量的大小更为敏感。
  • 对于需要跨句对齐的微调任务MNLI,gMLP与Transformer相比逊色一筹。对此作者发现加上一个128特征大小的单头注意力足以使得gMLP在任何NLP任务上取得比Transformer更好的效果。

gMLP由L个如下图所示的模块堆叠而成

设每个模块的输入 \(X \in \mathbb{R}^{n \times d}\), n为序列长度, d为特征维度。每个模块表达如下:
\(Z=\sigma(X U), \quad \tilde{Z}=s(Z), \quad Y=\tilde{Z} V\)
\(\sigma\) 是GELU等激活函数, U 和 V 和Transformer中的FFN类似都是线性映射。为了简洁表达上式中 省略了shortcuts, normalizations 和 biases。
上式中最重要的是能捕捉空间交互的 \(s(\cdot)\) 。如果上式去掉 \(s(\cdot)\) 那么将不再能进行空间交互和FFN 并无区别。文中作者选择名为 Spatial Gating Unit (SGU) 的模块作为 \(s(\cdot)\) 捕捉空间依赖。另外,gMLP在NLP、CV任务中遵循与BERT、ViT一样的输入输出规则。

Spatial Gating Unit:

为了能有跨token的交互, \(s(\cdot)\) 操作须在空间维度。可以简单的使用线性映射表示:
\(f_{W, b}(Z)=W Z+b\)
其中 \(W \in \mathbb{R}^{n \times n}\) 表示空间交互的映射参数。在self-attention中 W 是通过 Z 动态计算得到的。 此文对上式使用gating操作以便更好的训练,如下所示:
\(s(Z)=Z \odot f_{W, b}(Z)\)
为了训练更稳定,作者将 W 和 b 分别初始化为接近 0 与 1 来保证在开始训练时 \(f_{W, b} \approx 1\) 、 \(s(Z z) \approx Z\) 使得在开始阶段gMLP近似于FFN并在训练中逐渐学习到跨token的空间信息。
作者进一步发现将 Z 从通道维度分割成两部分 \(\left(Z_1, Z_2\right)\) 进行gating操作更有用,如下所示:
\(
s(Z)=Z_1 \odot f_{W, b}\left(Z_2\right)
\)
另外函数 \(f_{W, b}\)的输入通常需要normalizel以此提升模型的稳定性。

一些思考:这里的SpatialGatingUnit里面用到了一个通道split,然后再将分割后的两部分做乘法,让我想到了NAFnet中的simplegate,这个的作用一是减少计算量(相比于GELU)、另外引入门控机制,在通道维度进行通道交织,对于模型的效果表现很好。

作者进一步分析了SGU与现有的一些操作的相似之处:首先是Gated Linear Units (GLU) 与 SGU的区别在于SGU对spatial dimension而GLU对channel dimension; 其次SGU和
Squeeze-and-Excite (SE) 一样使用hadamard-product,只是SGU并没有跨通道的映射来保 证排列不变性;SGU的空间映射可以看作depthwise convolution不过SGU只学习跨通道只是, 并没有跨通道过滤器;SGU学习的是二阶空间交互 \(z_i z_j\) , self-attention学习的是三阶交互 \(q_i k_j v_k\) , SGU的复杂度为 \(n^2 e / 2\) 而self-attention的复杂度为 \(2 n^2 d_{\text {。 }}\)

实验:

1、Image Classification

此文首先将gMLP应用于图片分类,使用ImageNet数据集而且不使用额外数据。下表首先展示了gMLP用于图片分类的参数,gMLP和ViT/B16一样使用 16×16 个patch,同时采用和DeiT相似的正则化方法防止过拟合。

下表中gMLP与baselines在ImageNet上的结果表示gMLP取得了与视觉Transformer相当的结果,同时与其它MLP视觉模型相比,gMLP取得了准确率、速度权衡下最好的结果。

Masked Language Modeling with BERT:

此文同时将gMLP应用于masked language modeling(MLM)任务,对于预训练和微调任务,模型的输入输出规则都保持与BERT一致。

作者观察到在MLM任务最后学习到的空间映射矩阵总是Toeplitz-like matrics,如下图所示。所以作者认为gMLP是能从数据中学习到平移不变性的概念的,这使得gMLP实质起到了卷积核是整个序列长度的1-d卷积的作用。在接下来的MLM实验中,作者初始 W 为Toeplitz matrix。

Ablation: The Importance of Gating in gMLP for BERT’s Pretraining:下表展示了gMLP的各种变体与Transoformer模型、MLP-Mixer的比较,可以看到gMLP在与Transformer相同模型大小的情况下能达到与Transformer相当的效果。同时gating操作对于空间映射十分有用。同时下图还可视化了模型学习到的空间映射参数。

Case Study: The Behavior of gMLP as Model Size Increases:下表与下图展示了gMLP随着模型增大逐渐能有与Transformer相当的效果,可见Transformer的效果应该主要是依赖于模型尺寸而非self-attention。

  • Ablation: The Usefulness of Tiny Attention in BERT’s Finetuning:从上面的Case Study可以发现gMLP对于需要跨句子连接的finetuing任务可能不及Transformer,所以作者提出了gMLP的增强版aMLP。aMLP相较于gMLP仅增加了一个单头64的self-attention如下图所示:

从下图结果可以发现aMLP相较于gMLP极大提升了效果并在所有task超过了Transformer。

Vision MLP –ResMLP

Feedforward networks for image classification with data-efficient training

我们提出了ResMLP,一种完全基于多层感知机(MLP)进行图像分类的体系结构。 它是一个简单的残差网络,它交替(i)线性层,其中图像 patches在通道之间独立且相同地交互;以及(ii)两层前馈网络,其中通道中的每个 patch独立地相互作用。

CODE:

import torch
import numpy as np
from resmlp import ResMLP

img = torch.ones([1, 3, 224, 224])

model = ResMLP(in_channels=3, image_size=224, patch_size=16, num_classes=1000,
                 dim=384, depth=12, mlp_dim=384*4)

parameters = filter(lambda p: p.requires_grad, model.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)

out_img = model(img)

print("Shape of out :", out_img.shape)  # [B, in_channels, image_size, image_size]

本文作者提出了一种基于全连接层的图像分类网络。网络结构与MLP-Mixer相似,即先将输入图像拆分成若干patch,对每个patch通过全连接层转换为特征嵌入矩阵,该矩阵的两个维度分别表示channel维度(每个局部位置的特征维度)和patch维度(表示局部位置的维度)。首先将该矩阵转置后沿patch维度进行全连接层运算,实现不同patch之间的交互;再沿channel维度进行全连接运算,实现不同channel之间的交互。最后使用池化层和输出层获得分类结果。本文与MLP-Mixer的不同之处在于采用了更强的数据增强方法和蒸馏策略。

当采用现代的训练策略进行训练时,使用大量的数据增广和可选的蒸馏方法,可以在ImageNet上获得令人惊讶的良好精度/复杂度折衷。

Affine仿射变换:

函数名称:diag(x)
函数功能:构建一个n维的方阵,它的主对角线元素值取自向量x,其余元素都为0

Vision MLP系列–RepMLP

RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081)

CVPR 2022

RepMLP Block

Github source: https://github.com/DingXiaoH/RepMLP

最近公开了一系列视觉MLP论文,包括RepMLP、MLP-Mixer、ResMLP、gMLP等。在这个时间点出现关于MLP的一系列讨论是很合理的:

1) Transformer大火,很多研究者在拆解Transformer的过程中多多少少地对self-attention的必要性产生了疑问。去掉了self-attention,自然就剩MLP了。

2) 科学总是螺旋式上升的,“复兴”老方法(比如说另一篇“复兴”VGG的工作,RepVGG)总是喜闻乐见的。

这些论文引发了热烈的讨论,比如:

1) 这些模型到底是不是MLP?

2) 卷积和全连接(FC)的区别和联系是什么?FC是不是卷积,卷积是不是FC?

3) 真正的纯MLP为什么不行?

4) 所以MLP is all you need?

《RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition》。这篇文章讲了一个全连接层找到一份陌生的工作(直接进行feature map的变换),为了与那些已经为这份工作所特化的同胞(卷积层)们竞争,开始“内卷”的故事。

关键贡献在于,RepMLP用卷积去增强FC,既利用其全局性又赋予其局部性,并通过结构重参数化,将卷积融合到FC中去,从而在推理时去除卷积。

论文:RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition

代码:DingXiaoH/RepMLP

1. 为什么真正的纯MLP不太行?

我们一般认为多层感知机(MLP)是至少两层全连接层(FC)堆叠得到的模型,而且一般把同时含有卷积和MLP的模型(或模型中的一个模块)称为CNN。尽管目前大家对什么叫MLP的问题尚有争议(下图),我们不妨先定义一个任何人都会称之为MLP的100%纯MLP:

这个MLP在ImageNet上的输入是(3, 112, 112),第一层将其变为(32, 56, 56),第二层将其变为(64, 28, 28),然后global average pool,然后经过FC映射为1000类。这样总共只有三个FC,毫无疑问是MLP。这三层的参数为:

第一层:3x112x112x32x56x56 = 3.77G 参数

第二层:32x56x56x64x28x28 = 5.03G 参数

第三层:64×1000 = 64k 参数,忽略不计

看起来有点吓人,但这确实是一个处于A1位置的纯MLP应有的体量,虽然它只有两层,而且通道数只有32和64。除了减小通道数量,任何试图减小参数量的改动都将使其不再属于A1位置。比如说:

1) 先切块。把112×112的输入切成56×56的四块,每一块经过第一层变成28×28,再拼起来,这样第一层的参数量变成了3x56x56x32x28x28=236M,看起来好多了。但是,这破坏了全局性,因为分属于两块中的两点之间不再有联系了!换句话说,我们引入了一种局部性:一张图切成四块之后,每块中的任一像素只跟同块中的其他像素有联系。ViT,RepMLP和其他几篇MLP都用了这种操作或某种类似的操作。

2) 分组FC。正如卷积有分组卷积一样,FC也可以分组。由于torch里没有现成的算子,分组FC可以用分组1×1卷积实现。组数为g,参数量和计算量就会变成1/g。可惜,这也引入了局部性。RepMLP用了这种操作。

3) 把一个FC拆分成两次操作,第一次操作对channel维度线性重组,spatial共享参数(等价于1×1卷积);第二次操作对spatial维度线性重组,channel共享参数(等价于先转置后1×1卷积)。这思想可以类比于depthwise conv + 1×1 conv。MLP-Mixer使用这种操作,用两个各自都不具有全局性的操作实现了整体的全局性(而RepMLP使用另一种不同的机制,对不同的分块做pooling再连接,实现了这种全局性)。

所以,真正的100%纯MLP不太行,大家都在用各种花式操作做“伪MLP”的原因之一,就是体量太大。

这篇文章介绍的RepMLP属于B2的位置,不追求纯MLP。称其为“MLP”的原因是想强调卷积和FC的区别:RepMLP将卷积看成一种特殊的FC,显式地用卷积去强化FC(把FC变得具有局部性又不失全局性),指出了这样的FC强在哪里(如ResNet-50中,用一半通道数量的RepMLP替换3×3卷积就可以实现同等精度和55%加速),并用这种强化过的FC(及一些其他技巧)构造一种通用的CNN基本组件,提升多重任务性能。论文中说明了这里MLP的意思是推理时结构“不包含大于1×1的卷积”。

2. RepMLP:FC“内卷”,卷出性能

真正的100%纯MLP不太行的原因之二,是不具有局部先验。

在一张图片中,一个像素点跟它周围的像素点的关系往往比远在天边的另一个像素点更密切,这称为局部性。人类在识别图片的时候潜意识地利用这一点,称为局部先验。卷积网络符合局部先验,因为卷积核通过滑动窗口在图片上“一块一块地”寻找某种特征。

那么FC层呢?FC能自动学到这一点吗?在有限的数据量(ImageNet)和有限的计算资源前提(GPU)下,很难。

实验验证:下面我们假设FC层的输入是64x10x10的feature map直接 “展平”成的6400维向量。输出也是6400维向量,然后reshape成64x10x10的feature map。下图展示了FC学得的kernel中的一个切片的权值大小。简单地讲(详见论文),展示的这一部分表示在输出的第0个channel中随便找的一个采样点(6,6)(也就是图中黄框标出来的点)作用于第0个输入channel上的10×10个像素点的权值。颜色越深,表示权值越大。比如说,如果图中的(5,5)点颜色深,就表示这个FC层认为输出中的(0,6,6)点与输入中的(0,5,5)点关系紧密。

结果很明显,(6,6)周围的权值并没有颜色更深,也就是说FC并不认为这个点和周围点的联系更紧密。相反,似乎这个FC层认为(6,6)点与右上和右下部分关系更密切。实验也证明,不具有局部性的FC效果较差。

既然图像的局部性很强,FC把握不住,那怎么办呢?RepMLP提出,用卷积去增强FC(如下图所示,输入既被展平成向量并输入FC,又用不同大小的卷积核进行卷积,各自过BN后相加),并通过结构重参数化,将卷积融合到FC中去,从而在推理时去除卷积。

我们将卷积和FC之间建立联系,是因为卷积可以看成一个稀疏且存在重复参数的FC。如下图代码所示,给定输入X和卷积核conv_K,其卷积的结果等于X(直接展平成向量)和fc_K的矩阵乘,fc_K称为conv_K的等效FC核。尽管我们都相信这样的fc_K一定存在,但根据conv_K的值直接构造出fc_K的方法(下图中的convert_K函数)似乎不太简单。

本文提出了一种简洁优美的做法(见后文)。我们用这种方法构造出fc_K并打印出来,可以看出它是一个稀疏且有很多元素相同的矩阵(Toeplitz矩阵)。如下图的代码和结果所示。

RepMLP把卷积的输出和FC的输出相加,这样做的好处是:

1) 降低FLOPs,提高速度。用我们提出的方法把卷积全都转换为等效FC kernel后,由于矩阵乘法的可加性(AX + BX = (A+B)X),一个稀疏且共享参数的FC(Toeplitz矩阵)加一个不稀疏不共享参数的FC(全自由度的矩阵),可以等价转换为一个FC(其参数是这两个矩阵之和)。这样我们就可以将这些卷积等效地去掉。这一思路也属于结构重参数化(通过参数的等价转换实现结构的等价转换,如RepVGG)。

2) 在同等参数量的情况下,FC的FLOPs远低于卷积。

3) 相比于纯FC,这样做产生了局部性。注意这种局部性是我们“赋予”FC的,而不是让FC学到的。

4) 相比于卷积层,这样做使得相距遥远的两个点直接相连,具备了全局性。

这样做看起来像是让FC的“内部”含有卷积,所以也可以称为“内卷”。事实证明,跟人类相似,FC的“内卷”也可以提高性能。

只剩下一个问题了:我们相信存在一个FC kernel等价于卷积的卷积核,但是给定一个训练好的卷积核,怎么构造出FC kernel(Toeplitz矩阵)呢?

其实也很简单:FC kernel等于在单位矩阵reshape成的feature map上用卷积核做卷积的结果。这一做法是高效、可微、与具体的卷积算法和平台无关的。推导过程也很简洁(详见论文)。

现在,整个流程就很清晰了:

1) 训练时,既有FC又有卷积,输出相加。

2) 训练完成后,先把BN的参数“吸”到卷积核或FC中去(跟RepVGG一样),然后把每一个卷积转换成FC,把所有FC加到一起。从此以后,不再有卷积,只有FC。

3) 保存并部署转换后的模型。

现在我们再看一下用卷积增强后转换得到的FC kernel,可以看出采样点周围的权值变大了,现在(6,6)点更关注它旁边的输入点了。有趣的是,这里用到的最大卷积是7×7,但是7×7的范围(蓝色框)外还有一些值(红色框)比蓝框内的值大,这说明全局性也没有被局部性“淹没”。

一些其他设计

RepMLP中也用了一些其他设计,包括:

1) 用groupwise conv实现groupwise FC,减少参数和计算量。

2) 将输入分块(最近大家都会用的常见操作),进一步减少参数和计算量。如下图所示,H和W是feature map的分辨率,h和w是每一块的分辨率。

3) 用两个FC在不同分块之间建立联系,确保全局性。如下图所示。

实验结果

用RepMLP替换Res50中的部分结构,在ImageNet上有性能提升。将ImageNet pretrained模型迁移到语义分割和人脸上,也都有性能提升。

在ImageNet上的实验是在Res50中做的。考虑到Res50的主干通道较多(256、512、1024、2048),为了将RepMLP用到Res50中取得合理的trade-off,我们做了以下设计:

1)RepMLP Bottleneck Block:在RepMLP之前用1×1和3×3降维,RepMLP之后用3×3和1×1升维。这一结构类似于旷视在工程中探索并申请的专利GLFP(202010422194.X, Visual task processing method and device and electronic system,下图)。

2)RepMLP Light Block:在RepMLP之前用1×1大幅降维,之后用1×1大幅升维。降维/升维的幅度(8x)比Res50(4x)更大。

一些有趣的发现:

1) RepMLP中具有局部先验的成分(融合进FC的卷积),所以对于具有平移不变性的任务(ImageNet,Cityscapes语义分割)有效。

2) RepMLP中也具有不具有平移不变性的成分(大FC kernel),所以对于具有某种位置模式(例如人脸图像中,眼睛总是在鼻子上面)的任务也有效。

3) 由于FC和卷积的差别,RepMLP可以大幅增加参数而不降低速度(参数增加47%,ImageNet精度提升0.31%,速度仅降低2.2%)。

一些常见问题

RepMLP和ResMLP是什么关系?

相当于旺旺碎冰冰和王冰冰的关系。只是名字有点像。RepMLP中用卷积增强FC的思路也可以用在其他MLP架构中,应该也会有提升。另外,ResMLP、RepMLP和ResRep(去年做的一篇用重参数化做剪枝的论文)也没有关系。

把卷积融合进FC里,那FC不就是卷积了吗?

卷了,但不是完全卷,而且比卷积更强。上面可视化的图显示,转换后的kernel可以关注到卷积核的感受野以外的信息,因而表征能力更强。论文中报告的实验表明,这样的操作可以以一半的channel量达到与纯CNN相当的性能,速度更快,FLOPs更低。本文的关键也在于把卷积看成一种特殊的FC,然后考虑如何利用这种特殊性

所以MLP is all you need?

目前看来,还差得远。目前的方法多多少少都用到了切块等操作,都需要用某种方式降低参数量和引入局部性。真正的纯MLP(A1位置)依然还没有希望。真正纯MLP的一个大麻烦是总的参数量和输入分辨率耦合,因而改变输入分辨率会很困难。MLP-Mixer的一个缺点是不方便改变输入分辨率,所以它在ImageNet分类上的性能不容易迁移到其他任务上去。