{"id":7956,"date":"2022-09-24T16:10:00","date_gmt":"2022-09-24T08:10:00","guid":{"rendered":"http:\/\/139.9.1.231\/?p=7956"},"modified":"2022-10-08T20:18:21","modified_gmt":"2022-10-08T12:18:21","slug":"swin-mlp","status":"publish","type":"post","link":"http:\/\/139.9.1.231\/index.php\/2022\/09\/24\/swin-mlp\/","title":{"rendered":"Vision MLP &#8212;Swin-MLP"},"content":{"rendered":"\n<p>code:https:\/\/github.com\/microsoft\/Swin-Transformer<\/p>\n\n\n\n<p>Swin MLP \u4ee3\u7801\u6765\u81ea Swin Transformer \u7684\u5b98\u65b9\u5b9e\u73b0\u3002Swin Transformer \u4f5c\u8005\u4eec\u5728\u5df2\u6709\u6a21\u578b\u7684\u57fa\u7840\u4e0a\u5b9e\u73b0\u4e86 Swin MLP \u6a21\u578b\uff0c\u8bc1\u660e\u4e86 Window-based attention \u5bf9\u4e8e MLP \u6a21\u578b\u7684\u6709\u6548\u6027\u3002<\/p>\n\n\n\n<p><strong>\u628a\u5f20\u91cf (B, H, W, C) \u5206\u6210 window (B\u00d7H\/M\u00d7W\/M, M, M, C)<\/strong>\uff0c\u5176\u4e2dM\u662f window_size\u3002\u8fd9\u4e00\u6b65\u76f8\u5f53\u4e8e\u5f97\u5230 B\u00d7H\/M\u00d7W\/M \u4e2a\u5927\u5c0f\u4e3a (M, M, C) \u7684 window\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def window_partition(x, window_size):\r\n    \"\"\"\r\n    Args:\r\n        x: (B, H, W, C)\r\n        window_size (int): window size\r\n\r\n    Returns:\r\n        windows: (num_windows*B, window_size, window_size, C)\r\n    \"\"\"\r\n    B, H, W, C = x.shape\r\n    x = x.view(B, H \/\/ window_size, window_size, W \/\/ window_size, window_size, C)\r\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\r\n    return windows\r\n<\/code><\/pre>\n\n\n\n<p><strong>\u628a window (B\u00d7H\/M\u00d7W\/M, M, M, C) \u53d8\u56de\u5f20\u91cf (B, H, W, C)\u3002<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def window_reverse(windows, window_size, H, W):\r\n    \"\"\"\r\n    Args:\r\n        windows: (num_windows*B, window_size, window_size, C)\r\n        window_size (int): Window size\r\n        H (int): Height of image\r\n        W (int): Width of image\r\n\r\n    Returns:\r\n        x: (B, H, W, C)\r\n    \"\"\"\r\n    B = int(windows.shape&#91;0] \/ (H * W \/ window_size \/ window_size))\r\n    x = windows.view(B, H \/\/ window_size, W \/\/ window_size, window_size, window_size, -1)\r\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\r\n    return x\n<\/code><\/pre>\n\n\n\n<p><strong>\u4e00\u4e2a Swin MLP Block<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class SwinMLPBlock(nn.Module):\r\n    r\"\"\" Swin MLP Block.\r\n\r\n    Args:\r\n        dim (int): Number of input channels.\r\n        input_resolution (tuple&#91;int]): Input resolution.\r\n        num_heads (int): Number of attention heads.\r\n        window_size (int): Window size.\r\n        shift_size (int): Shift size for SW-MSA.\r\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\r\n        drop (float, optional): Dropout rate. Default: 0.0\r\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\r\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\r\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\r\n    \"\"\"\r\n\r\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\r\n                 mlp_ratio=4., drop=0., drop_path=0.,\r\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\r\n        super().__init__()\r\n        self.dim = dim\r\n        self.input_resolution = input_resolution\r\n        self.num_heads = num_heads\r\n        self.window_size = window_size\r\n        self.shift_size = shift_size\r\n        self.mlp_ratio = mlp_ratio\r\n        if min(self.input_resolution) &lt;= self.window_size:\r\n            # if window size is larger than input resolution, we don't partition windows\r\n            self.shift_size = 0\r\n            self.window_size = min(self.input_resolution)\r\n        assert 0 &lt;= self.shift_size &lt; self.window_size, \"shift_size must in 0-window_size\"\r\n\r\n        self.padding = &#91;self.window_size - self.shift_size, self.shift_size,\r\n                        self.window_size - self.shift_size, self.shift_size]  # P_l,P_r,P_t,P_b\r\n\r\n        self.norm1 = norm_layer(dim)\r\n        # use group convolution to implement multi-head MLP\r\n        self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2,\r\n                                     self.num_heads * self.window_size ** 2,\r\n                                     kernel_size=1,\r\n                                     groups=self.num_heads)\r\n\r\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\r\n        self.norm2 = norm_layer(dim)\r\n        mlp_hidden_dim = int(dim * mlp_ratio)\r\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\r\n\r\n    def forward(self, x):\r\n        H, W = self.input_resolution\r\n        B, L, C = x.shape\r\n        assert L == H * W, \"input feature has wrong size\"\r\n\r\n        shortcut = x\r\n        x = self.norm1(x)\r\n        x = x.view(B, H, W, C)\r\n\r\n        # shift\r\n        if self.shift_size > 0:\r\n            P_l, P_r, P_t, P_b = self.padding\r\n            shifted_x = F.pad(x, &#91;0, 0, P_l, P_r, P_t, P_b], \"constant\", 0)\r\n        else:\r\n            shifted_x = x\r\n        _, _H, _W, _ = shifted_x.shape\r\n\r\n        # partition windows\r\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\r\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\r\n\r\n        # Window\/Shifted-Window Spatial MLP\r\n        x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C \/\/ self.num_heads)\r\n        x_windows_heads = x_windows_heads.transpose(1, 2)  # nW*B, nH, window_size*window_size, C\/\/nH\r\n        x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size,\r\n                                                  C \/\/ self.num_heads)\r\n        spatial_mlp_windows = self.spatial_mlp(x_windows_heads)  # nW*B, nH*window_size*window_size, C\/\/nH\r\n        spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size,\r\n                                                       C \/\/ self.num_heads).transpose(1, 2)\r\n        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C)\r\n\r\n        # merge windows\r\n        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C)\r\n        shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W)  # B H' W' C\r\n\r\n        # reverse shift\r\n        if self.shift_size > 0:\r\n            P_l, P_r, P_t, P_b = self.padding\r\n            x = shifted_x&#91;:, P_t:-P_b, P_l:-P_r, :].contiguous()\r\n        else:\r\n            x = shifted_x\r\n        x = x.view(B, H * W, C)\r\n\r\n        # FFN\r\n        x = shortcut + self.drop_path(x)\r\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\r\n\r\n        return x\r\n\r\n    def extra_repr(self) -> str:\r\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\r\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"<\/code><\/pre>\n\n\n\n<p>\u6ce8\u610f&nbsp;<strong>F.pad(x, [0, 0, P_l, P_r, P_t, P_b], &#8220;constant&#8221;, 0)&nbsp;<\/strong>\u7684\u5bf9\u8c61\u662f x\uff0c\u7ef4\u5ea6\u662f (B, H, W, C)\u3002<br>padding\u76f8\u5f53\u4e8e\u662f\u7b2c3\u7ef4 (C \u8fd9\u4e00\u7ef4) \u4e0d\u586b\u5145\uff0c\u7b2c2\u7ef4 (W \u8fd9\u4e00\u7ef4) \u5de6\u53f3\u5206\u522b\u586b\u5145 P_l, P_r\uff0c\u7b2c1\u7ef4 (H \u8fd9\u4e00\u7ef4) \u5de6\u53f3\u5206\u522b\u586b\u5145 P_t, P_b\u3002<br><strong>x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C\uff1a<\/strong><br>\u8fd9\u53e5\u4ee3\u7801\u628a shifted_x \u5206\u6210 nW*B \u4e2a windows\uff0c\u5176\u4e2d\u6bcf\u4e2a window \u7684\u7ef4\u5ea6\u662f (window_size, window_size, C)\u3002<\/p>\n\n\n\n<p># reverse shift<br>if self.shift_size > 0:<br>P_l, P_r, P_t, P_b = self.padding<br>x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()<br>else:<br>x = shifted_x<br>\u8fd9\u91cc\u662f\u5982\u679c\u8fdb\u884c\u4e86 shift \u64cd\u4f5c\uff0c\u5219\u6700\u540e\u53d6\u5f97\u7ed3\u679c\u4e5f\u5e94\u8be5\u662f\u6ca1\u6709 padding \u7684\u90e8\u5206\uff0c\u6b63\u597d\u662f shifted_x[:, P_t:-P_b, P_l:-P_r, :]\u3002<\/p>\n\n\n\n<p><strong>\u4e00\u4e2a Swin MLP Block \u7684 FLOPs<\/strong>\uff0c\u6ce8\u610f WSA \u7684\u8ba1\u7b97\u91cf\u662f\uff1a<\/p>\n\n\n\n<p><strong>FLOPs (WSA) = (window_size * window_size)^2 * dim * number_window<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def flops(self):\r\n        flops = 0\r\n        H, W = self.input_resolution\r\n        # norm1\r\n        flops += self.dim * H * W\r\n\r\n        # Window\/Shifted-Window Spatial MLP\r\n        if self.shift_size > 0:\r\n            nW = (H \/ self.window_size + 1) * (W \/ self.window_size + 1)\r\n        else:\r\n            nW = H * W \/ self.window_size \/ self.window_size\r\n        flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size)\r\n        # mlp\r\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\r\n        # norm2\r\n        flops += self.dim * H * W\r\n        return flops<\/code><\/pre>\n\n\n\n<p><strong>\u6bcf\u4e2a stage \u4e4b\u95f4\u7684 PatchMerging\u8fde\u63a5<\/strong>\uff0c\u628a resolution \u53d8\u4e3a\u4e00\u534a\uff0cdim \u53d8\u4e3a2\u500d\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class PatchMerging(nn.Module):\n    r\"\"\" Patch Merging Layer.\n\n    Args:\n        input_resolution (tuple&#91;int]): Resolution of input feature.\n        dim (int): Number of input channels.\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x&#91;:, 0::2, 0::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x1 = x&#91;:, 1::2, 0::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x2 = x&#91;:, 0::2, 1::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x3 = x&#91;:, 1::2, 1::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x = torch.cat(&#91;x0, x1, x2, x3], -1)  <em># B H\/2 W\/2 4*C<\/em>\n        x = x.view(B, -1, 4 * C)  <em># B H\/2*W\/2 4*C<\/em>\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x\n\n    def flops(self):\n        H, W = self.input_resolution\n        <em># norm<\/em>\n        flops = H * W * self.dim\n        <em># reduction<\/em>\n        flops += (H \/\/ 2) * (W \/\/ 2) * 4 * self.dim * 2 * self.dim\n        return flops<\/code><\/pre>\n\n\n\n<ul><li>Patch Merging \u64cd\u4f5c\u628a\u76f8\u90bb\u7684 2\u00d72 \u4e2a tokens \u7ed9\u5408\u5e76\u5230\u4e00\u8d77\uff0c\u5f97\u5230\u7684 token \u7684\u7ef4\u5ea6\u662f4C\u3002<br>Patch Merging \u64cd\u4f5c\u518d\u901a\u8fc7\u4e00\u6b21\u7ebf\u6027\u53d8\u6362\u628a\u7ef4\u5ea6\u964d\u4e3a2C\u3002<\/li><\/ul>\n\n\n\n<p><strong>\u4e00\u4e2a Swin MLP Layer<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class BasicLayer(nn.Module):\n    \"\"\" A basic Swin MLP layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple&#91;int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        drop (float, optional): Dropout rate. Default: 0.0\n        drop_path (float | tuple&#91;float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., drop=0., drop_path=0.,\n                 norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n\n        <em># build blocks<\/em>\n        self.blocks = nn.ModuleList(&#91;\n            SwinMLPBlock(dim=dim, input_resolution=input_resolution,\n                         num_heads=num_heads, window_size=window_size,\n                         shift_size=0 if (i % 2 == 0) else window_size \/\/ 2,\n                         mlp_ratio=mlp_ratio,\n                         drop=drop,\n                         drop_path=drop_path&#91;i] if isinstance(drop_path, list) else drop_path,\n                         norm_layer=norm_layer)\n            for i in range(depth)])\n\n        <em># patch merging layer<\/em>\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        return x\n\n    def extra_repr(self) -&gt; str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops<\/code><\/pre>\n\n\n\n<ul><li>\u5305\u542b depth \u4e2a Swin MLP Block\u3002<br>\u6ce8\u610f\u8ba1\u7b97 FLOPs \u7684\u65b9\u5f0f\uff1a\u6bcf\u4e2a blk \u548c downsample \u90fd\u81ea\u5e26 flops() \u65b9\u6cd5\uff0c\u53ef\u4ee5\u76f4\u63a5\u6765\u8c03\u7528\u3002<\/li><\/ul>\n\n\n\n<p><strong>PatchEmbedded \u64cd\u4f5c<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class PatchEmbed(nn.Module):\n    r\"\"\" Image to Patch Embedding\n\n    Args:\n        img_size (int): Image size.  Default: 224.\n        patch_size (int): Patch token size. Default: 4.\n        in_chans (int): Number of input image channels. Default: 3.\n        embed_dim (int): Number of linear projection output channels. Default: 96.\n        norm_layer (nn.Module, optional): Normalization layer. Default: None\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = &#91;img_size&#91;0] \/\/ patch_size&#91;0], img_size&#91;1] \/\/ patch_size&#91;1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution&#91;0] * patches_resolution&#91;1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        <em># FIXME look at relaxing size constraints<\/em>\n        assert H == self.img_size&#91;0] and W == self.img_size&#91;1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size&#91;0]}*{self.img_size&#91;1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  <em># B Ph*Pw C<\/em>\n        if self.norm is not None:\n            x = self.norm(x)\n        return x\n\n    def flops(self):\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size&#91;0] * self.patch_size&#91;1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops<\/code><\/pre>\n\n\n\n<ul><li>\u548c ViT \u7684 Patch Embedded \u64cd\u4f5c\u4e00\u6837\uff0c\u672c\u8d28\u4e0a\u662f\u4e00\u4e2a K=patch size\uff0cs=patch size \u7684 nn.Conv2d \u64cd\u4f5c\uff0c\u6ce8\u610f\u5377\u79ef FLOPs \u7684\u8ba1\u7b97\u516c\u5f0f\u5373\u53ef\u3002<\/li><\/ul>\n\n\n\n<p><strong>SwinMLP \u6574\u4f53\u6a21\u578b\u67b6\u6784<\/strong><\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class SwinMLP(nn.Module):\n    r\"\"\" Swin MLP\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin MLP layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        drop_rate (float): Dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=&#91;2, 2, 6, 2], num_heads=&#91;3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.ape = ape\n        self.patch_norm = patch_norm\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n        self.mlp_ratio = mlp_ratio\n\n        # split image into non-overlapping patches\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        # absolute position embedding\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        # stochastic depth\n        dpr = &#91;x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule\n\n        # build layers\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n                               input_resolution=(patches_resolution&#91;0] \/\/ (2 ** i_layer),\n                                                 patches_resolution&#91;1] \/\/ (2 ** i_layer)),\n                               depth=depths&#91;i_layer],\n                               num_heads=num_heads&#91;i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               drop=drop_rate,\n                               drop_path=dpr&#91;sum(depths&#91;:i_layer]):sum(depths&#91;:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer &lt; self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features)\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes &gt; 0 else nn.Identity()\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, (nn.Linear, nn.Conv1d)):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  # B L C\n        x = self.avgpool(x.transpose(1, 2))  # B C 1\n        x = torch.flatten(x, 1)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        # adaptive average pool\n        flops += self.num_features * self.patches_resolution&#91;0] * self.patches_resolution&#91;1] \/\/ (2 ** self.num_layers)\n        # head\n        flops += self.num_features * self.num_classes\n        return flops<\/code><\/pre>\n\n\n\n<ul><li>\u75314\u4e2a Stage \u7ec4\u6210\uff0c\u6bcf\u4e2a Stage \u7531 BasicLayer \u5b9e\u73b0\u3002<br>\u4f20\u5165\u7684 depths \u4ee3\u8868\u6bcf\u4e2a Stage \u7684\u5c42\u6570\uff0c\u6bd4\u5982 Swin-T \u5c31\u662f\uff1a[2, 2, 6, 2]\u3002<\/li><\/ul>\n","protected":false},"excerpt":{"rendered":"<p>code:https:\/\/github.com\/microsoft\/Swin-Transformer Swin &hellip; <a href=\"http:\/\/139.9.1.231\/index.php\/2022\/09\/24\/swin-mlp\/\" class=\"more-link\">\u7ee7\u7eed\u9605\u8bfb<span class=\"screen-reader-text\">Vision MLP &#8212;Swin-MLP<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":[],"categories":[30,4,9],"tags":[],"_links":{"self":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/7956"}],"collection":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/comments?post=7956"}],"version-history":[{"count":4,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/7956\/revisions"}],"predecessor-version":[{"id":8042,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/7956\/revisions\/8042"}],"wp:attachment":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/media?parent=7956"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/categories?post=7956"},{"taxonomy":"post_tag","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/tags?post=7956"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}