{"id":8652,"date":"2022-10-04T10:43:51","date_gmt":"2022-10-04T02:43:51","guid":{"rendered":"http:\/\/139.9.1.231\/?p=8652"},"modified":"2022-10-06T17:27:24","modified_gmt":"2022-10-06T09:27:24","slug":"swin-transformer-code","status":"publish","type":"post","link":"http:\/\/139.9.1.231\/index.php\/2022\/10\/04\/swin-transformer-code\/","title":{"rendered":"Swin Transformer \u4ee3\u7801\u8be6\u89e3"},"content":{"rendered":"\n<p class=\"has-light-pink-background-color has-background\"><strong>code\uff1a<a href=\"https:\/\/github.com\/microsoft\/Swin-Transformer\" target=\"_blank\" rel=\"noreferrer noopener\">https:\/\/github.com\/microsoft\/Swin-Transformer<\/a><\/strong><\/p>\n\n\n\n<p class=\"has-bright-blue-background-color has-background\"><strong>\u4ee3\u7801\u8be6\u89e3\uff1a <a href=\"https:\/\/zhuanlan.zhihu.com\/p\/367111046\">https:\/\/zhuanlan.zhihu.com\/p\/367111046<\/a><\/strong><\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"278\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-1024x278.png\" alt=\"\" class=\"wp-image-8713\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-1024x278.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-300x82.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-768x209.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image.png 1284w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" 
\/><\/figure>\n\n\n\n<p>\u9884\u5904\u7406\uff1a<\/p>\n\n\n\n<p>\u5bf9\u4e8e\u5206\u7c7b\u6a21\u578b\uff0c\u8f93\u5165\u56fe\u50cf\u5c3a\u5bf8\u4e3a&nbsp;224\u00d7224\u00d73&nbsp;\uff0c\u5373&nbsp;H=W=224&nbsp;\u3002\u6309\u7167\u539f\u6587\u63cf\u8ff0\uff0c\u6a21\u578b\u5148\u5c06\u56fe\u50cf\u5206\u5272\u6210\u6bcf\u5757\u5927\u5c0f\u4e3a&nbsp;4\u00d74&nbsp;\u7684patch\uff0c\u90a3\u4e48\u5c31\u4f1a\u6709&nbsp;56\u00d756&nbsp;\u4e2apatch\uff0c\u8fd9\u5c31\u662f\u521d\u59cbresolution\uff0c\u4e5f\u662f\u540e\u9762\u6bcf\u4e2astage\u4f1a\u964d\u91c7\u6837\u7684\u7ef4\u5ea6\u3002\u540e\u9762\u6bcf\u4e2astage\u90fd\u4f1a\u964d\u91c7\u6837\u65f6\u957f\u5bbd\u964d\u5230\u4e00\u534a\uff0c\u7279\u5f81\u6570\u52a0\u500d\u3002\u6309\u7167\u539f\u6587\u53ca\u539f\u56fe\u63cf\u8ff0\uff0c\u5212\u5206\u7684\u6bcf\u4e2apatch\u5177\u6709&nbsp;4\u00d74\u00d73=48&nbsp;\u7ef4\u7279\u5f81\u3002<\/p>\n\n\n\n<ul><li>\u5b9e\u9645\u5728\u4ee3\u7801\u4e2d\uff0c\u9996\u5148\u4f7f\u7528\u4e86PatchEmbed\u6a21\u5757\uff08\u8fd9\u91cc\u7684PatchEmbed\u5305\u62ec\u4e0a\u56fe\u4e2d\u7684Linear Embedding \u548c patch partition\u5c42\uff09\uff0c\u5b9a\u4e49\u5982\u4e0b\uff1a<\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>class PatchEmbed(nn.Module):\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): # embed_dim\u5c31\u662f\u4e0a\u56fe\u4e2d\u7684C\u8d85\u53c2\u6570\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        patches_resolution = &#91;img_size&#91;0] \/\/ patch_size&#91;0], img_size&#91;1] \/\/ patch_size&#91;1]]\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.patches_resolution = patches_resolution\n        self.num_patches = patches_resolution&#91;0] * patches_resolution&#91;1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, 
stride=patch_size)\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        <em># FIXME look at relaxing size constraints<\/em>\n        assert H == self.img_size&#91;0] and W == self.img_size&#91;1], \\\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size&#91;0]}*{self.img_size&#91;1]}).\"\n        x = self.proj(x).flatten(2).transpose(1, 2)  <em># B Ph*Pw C<\/em>\n        if self.norm is not None:\n            x = self.norm(x)\n        return x<\/code><\/pre>\n\n\n\n<p>\u53ef\u4ee5\u770b\u5230\uff0c\u5b9e\u9645\u64cd\u4f5c\u4f7f\u7528\u4e86\u4e00\u4e2a\u5377\u79ef\u5c42conv2d(3, 96, 4, 4)\uff0c\u76f4\u63a5\u5c31\u505a\u4e86\u5212\u5206patch\u548c\u7f16\u7801\u521d\u59cb\u7279\u5f81\u7684\u5de5\u4f5c\uff0c\u5bf9\u4e8e\u8f93\u5165&nbsp;x:B\u00d73\u00d7224\u00d7224&nbsp;\uff0c\u7ecf\u8fc7\u4e00\u5c42conv2d\u548cLayerNorm\u5f97\u5230&nbsp;x:B\u00d7562\u00d796&nbsp;\u3002\u7136\u540e\u4f5c\u4e3a\u5bf9\u6bd4\uff0c\u53ef\u4ee5\u9009\u62e9\u6027\u5730\u52a0\u4e0a\u6bcf\u4e2apatch\u7684\u7edd\u5bf9\u4f4d\u7f6e\u7f16\u7801\uff0c\u539f\u6587\u5b9e\u9a8c\u8868\u793a\u8fd9\u79cd\u505a\u6cd5\u4e0d\u597d\uff0c\u56e0\u6b64\u4e0d\u4f1a\u91c7\u7528\uff08ape=false\uff09\u3002\u6700\u540e\u7ecf\u8fc7\u4e00\u5c42dropout\uff0c\u81f3\u6b64\uff0c\u9884\u5904\u7406\u5b8c\u6210\u3002\u53e6\u5916\uff0c\u8981\u6ce8\u610f\u7684\u662f\uff0c\u4ee3\u7801\u548c\u4e0a\u9762\u6d41\u7a0b\u56fe\u5e76\u4e0d\u7b26\uff0c\u5176\u5b9e\u5728stage 1\u4e4b\u524d\uff0c\u5373\u9884\u5904\u7406\u5b8c\u6210\u540e\uff0c\u7ef4\u5ea6\u5df2\u7ecf\u662f&nbsp;H\/4\u00d7W\/4\u00d7C&nbsp;\uff0cstage 1\u4e4b\u540e\u5df2\u7ecf\u662f&nbsp;H\/8\u00d7W\/8\u00d72C&nbsp;\uff0c\u4e0d\u8fc7\u5728stage 
4\u540e\u4e0d\u518d\u964d\u91c7\u6837\uff0c\u5f97\u5230\u7684\u8fd8\u662f&nbsp;H\/32\u00d7W\/32\u00d78C&nbsp;\u3002<\/p>\n\n\n\n<h3>stage\u5904\u7406<\/h3>\n\n\n\n<p>\u6211\u4eec\u5148\u68b3\u7406\u6574\u4e2astage\u7684\u5927\u4f53\u8fc7\u7a0b\uff0c\u628a\u7b80\u5355\u7684\u90e8\u5206\u5148\u8bf4\u4e86\uff0c\u518d\u6df1\u5165\u5230\u590d\u6742\u5f97\u7684\u7ec6\u8282\u3002\u6bcf\u4e2astage\uff0c\u5373\u4ee3\u7801\u4e2d\u7684BasicLayer\uff0c\u7531\u82e5\u5e72\u4e2ablock\u7ec4\u6210\uff0c\u800cblock\u7684\u6570\u76ee\u7531depth\u5217\u8868\u4e2d\u7684\u5143\u7d20\u51b3\u5b9a\u3002\u6bcf\u4e2ablock\u5c31\u662fW-MSA\uff08window-multihead self attention\uff09\u6216\u8005SW-MSA\uff08shift window multihead self attention\uff09\uff0c\u4e00\u822c\u6709\u5076\u6570\u4e2ablock\uff0c\u4e24\u79cdSA\u4ea4\u66ff\u51fa\u73b0\uff0c\u6bd4\u59826\u4e2ablock\uff0c0\uff0c2\uff0c4\u662fW-MSA\uff0c1\uff0c3\uff0c5\u662fSW-MSA\u3002\u5728\u7ecf\u5386\u5b8c\u4e00\u4e2astage\u540e\uff0c\u4f1a\u8fdb\u884c\u4e0b\u91c7\u6837\uff0c\u5b9a\u4e49\u7684\u4e0b\u91c7\u6837\u6bd4\u8f83\u6709\u610f\u601d\u3002\u6bd4\u5982\u8fd8\u662f&nbsp;56\u00d756&nbsp;\u4e2apatch\uff0c\u56db\u4e2a\u4e3a\u4e00\u7ec4\uff0c\u5206\u522b\u53d6\u6bcf\u7ec4\u4e2d\u7684\u5de6\u4e0a\uff0c\u53f3\u4e0a\u3001\u5de6\u4e0b\u3001\u53f3\u4e0b\u5806\u53e0\u4e00\u8d77\uff0c\u7ecf\u8fc7\u4e00\u4e2alayernorm\uff0clinear\u5c42\uff0c\u5b9e\u73b0\u7ef4\u5ea6\u4e0b\u91c7\u6837\u3001\u7279\u5f81\u52a0\u500d\u7684\u6548\u679c\u3002\u5b9e\u9645\u4e0a\u5b83\u53ef\u4ee5\u770b\u6210\u4e00\u79cd<strong>\u52a0\u6743\u6c60\u5316\u7684\u8fc7\u7a0b<\/strong>\u3002\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class PatchMerging(nn.Module):\n    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.dim = dim\n        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n        self.norm = norm_layer(4 * dim)\n\n  
  def forward(self, x):\n        \"\"\"\n        x: B, H*W, C\n        \"\"\"\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n        assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n\n        x = x.view(B, H, W, C)\n\n        x0 = x&#91;:, 0::2, 0::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x1 = x&#91;:, 1::2, 0::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x2 = x&#91;:, 0::2, 1::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x3 = x&#91;:, 1::2, 1::2, :]  <em># B H\/2 W\/2 C<\/em>\n        x = torch.cat(&#91;x0, x1, x2, x3], -1)  <em># B H\/2 W\/2 4*C<\/em>\n        x = x.view(B, -1, 4 * C)  <em># B H\/2*W\/2 4*C<\/em>\n\n        x = self.norm(x)\n        x = self.reduction(x)\n\n        return x<\/code><\/pre>\n\n\n\n<p>\u5728\u7ecf\u5386\u5b8c4\u4e2astage\u540e\uff0c\u5f97\u5230\u7684\u662f&nbsp;(H\/32\u00d7W\/32)\u00d78C&nbsp;\u7684\u7279\u5f81\uff0c\u5c06\u5176\u8f6c\u5230&nbsp;8C\u00d7(H\/32\u00d7W\/32)&nbsp;\u540e\uff0c\u63a5\u4e00\u4e2aAdaptiveAvgPool1d(1)\uff0c\u5168\u5c40\u5e73\u5747\u6c60\u5316\uff0c\u5f97\u5230&nbsp;8C&nbsp;\u7279\u5f81\uff0c\u6700\u540e\u63a5\u4e00\u4e2a\u5206\u7c7b\u5668\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"360\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-1-1024x360.png\" alt=\"\" class=\"wp-image-8727\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-1-1024x360.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-1-300x106.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-1-768x270.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-1.png 1501w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" \/><figcaption> PatchMerging <\/figcaption><\/figure>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"351\" 
src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-2-1024x351.png\" alt=\"\" class=\"wp-image-8731\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-2-1024x351.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-2-300x103.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-2-768x264.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-2.png 1075w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<h3>Block\u5904\u7406<\/h3>\n\n\n\n<p>SwinTransformerBlock\u7684\u7ed3\u6784\uff0c\u7531LayerNorm\u5c42\u3001windowAttention\u5c42\uff08Window MultiHead self -attention\uff0c W-MSA\uff09\u3001MLP\u5c42\u4ee5\u53cashiftWindowAttention\u5c42\uff08SW-MSA\uff09\u7ec4\u6210\u3002<\/p>\n\n\n\n<div class=\"wp-block-image\"><figure class=\"aligncenter size-full\"><img loading=\"lazy\" width=\"385\" height=\"430\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-5.png\" alt=\"\" class=\"wp-image-8740\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-5.png 385w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-5-269x300.png 269w\" sizes=\"(max-width: 385px) 100vw, 385px\" \/><\/figure><\/div>\n\n\n\n<p>\u4e0a\u9762\u8bf4\u5230\u6709\u4e24\u79cdblock\uff0cblock\u7684\u4ee3\u7801\u5982\u4e0b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class SwinTransformerBlock(nn.Module):\n    r\"\"\" Swin Transformer Block.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple&#91;int]): Input resulotion.\n        num_heads (int): Number of attention heads.\n        window_size (int): Window size.\n        shift_size (int): Shift size for SW-MSA.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float, optional): Stochastic depth rate. Default: 0.0\n        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = mlp_ratio\n        if min(self.input_resolution) &lt;= self.window_size:\n            # if window size is larger than input resolution, we don't partition windows\n            self.shift_size = 0\n            self.window_size = min(self.input_resolution)\n        assert 0 &lt;= self.shift_size &lt; self.window_size, \"shift_size must in 0-window_size\"\n\n        # \u5de6\u56fe\u4e2d\u6700\u4e0b\u8fb9\u7684LN\u5c42layerNorm\u5c42\n        self.norm1 = norm_layer(dim)\n        # W_MSA\u5c42\u6216\u8005SW-MSA\u5c42\uff0c\u8be6\u7ec6\u7684\u4ecb\u7ecd\u770bWindowAttention\u90e8\u5206\u7684\u4ee3\u7801\n        self.attn = WindowAttention(\n            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        self.drop_path = DropPath(drop_path) if drop_path &gt; 0. 
else nn.Identity()\n        # \u5de6\u56fe\u4e2d\u95f4\u90e8\u5206\u7684LN\u5c42\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        # \u5de6\u56fe\u6700\u4e0a\u8fb9\u7684MLP\u5c42\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        # \u8fd9\u91cc\u5229\u7528shift_size\u63a7\u5236\u662f\u5426\u6267\u884cshift window\u64cd\u4f5c\n        # \u5f53shift_size\u4e3a0\u65f6\uff0c\u4e0d\u6267\u884cshift\u64cd\u4f5c\uff0c\u5bf9\u5e94W-MSA\uff0c\u4e5f\u5c31\u662f\u5728\u6bcf\u4e2astage\u4e2d,W-MSA\u4e0eSW-MSA\u4ea4\u66ff\u51fa\u73b0\n        # \u4f8b\u5982\u7b2c\u4e00\u4e2astage\u4e2d\u5b58\u5728\u4e24\u4e2ablock\uff0c\u90a3\u4e48\u7b2c\u4e00\u4e2ashift_size=0\u5c31\u662fW-MSA\uff0c\u7b2c\u4e8c\u4e2ashift_size\u4e0d\u4e3a0\n        # \u5c31\u662fSW-MSA\n        if self.shift_size &gt; 0:\n            # calculate attention mask for SW-MSA\n            H, W = self.input_resolution\n            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1\n#slice() \u51fd\u6570\u5b9e\u73b0\u5207\u7247\u5bf9\u8c61\uff0c\u4e3b\u8981\u7528\u5728\u5207\u7247\u64cd\u4f5c\u51fd\u6570\u91cc\u7684\u53c2\u6570\u4f20\u9012\u3002class slice(start, stop&#91;, step])\n            h_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            w_slices = (slice(0, -self.window_size),\n                        slice(-self.window_size, -self.shift_size),\n                        slice(-self.shift_size, None))\n            cnt = 0\n            for h in h_slices:\n                for w in w_slices:\n                    img_mask&#91;:, h, w, :] = cnt\n                    cnt += 1\n## \u4e0a\u8ff0\u64cd\u4f5c\u662f\u4e3a\u4e86\u7ed9\u6bcf\u4e2a\u7a97\u53e3\u7ed9\u4e0a\u7d22\u5f15\n\n            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n      
      mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n        else:\n            attn_mask = None\n\n        self.register_buffer(\"attn_mask\", attn_mask)\n\n    def forward(self, x):\n        H, W = self.input_resolution\n        B, L, C = x.shape\n        assert L == H * W, \"input feature has wrong size\"\n\n        shortcut = x\n        x = self.norm1(x)\n        x = x.view(B, H, W, C)\n\n        # cyclic shift\n        # \u5982\u679c\u9700\u8981\u8ba1\u7b97 SW-MSA\u5c31\u9700\u8981\u8fdb\u884c\u5faa\u73af\u79fb\u4f4d\u3002\n        if self.shift_size &gt; 0:\n            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n        else:\n            shifted_x = x\n\n        # partition windows\n        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C\n        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C\n\n        # W-MSA\/SW-MSA\n        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C\n\n        # merge windows\n        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C\n\n        # reverse cyclic shift\n        if self.shift_size &gt; 0:\n#shifts (python:int \u6216 tuple of python:int) \u2014\u2014 
\u5f20\u91cf\u5143\u7d20\u79fb\u4f4d\u7684\u4f4d\u6570\u3002\u5982\u679c\u8be5\u53c2\u6570\u662f\u4e00\u4e2a\u5143\u7ec4\uff08\u4f8b\u5982shifts=(x,y)\uff09\uff0cdims\u5fc5\u987b\u662f\u4e00\u4e2a\u76f8\u540c\u5927\u5c0f\u7684\u5143\u7ec4\uff08\u4f8b\u5982dims=(a,b)\uff09\uff0c\u76f8\u5f53\u4e8e\u5728\u7b2ca\u7ef4\u5ea6\u79fbx\u4f4d\uff0c\u5728b\u7ef4\u5ea6\u79fby\u4f4d\n            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n        else:\n            x = shifted_x\n        x = x.view(B, H * W, C)\n\n        # FFN\n        x = shortcut + self.drop_path(x)\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n        return x\n\n    def extra_repr(self) -&gt; str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n               f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n\n    def flops(self):\n        flops = 0\n        H, W = self.input_resolution\n        # norm1\n        flops += self.dim * H * W\n        # W-MSA\/SW-MSA\n        nW = H * W \/ self.window_size \/ self.window_size\n        flops += nW * self.attn.flops(self.window_size * self.window_size)\n        # mlp\n        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio\n        # norm2\n        flops += self.dim * H * W\n        return 
flops<\/code><\/pre>\n\n\n\n<h3>W-MSA<\/h3>\n\n\n\n<p>W-MSA\u6bd4\u8f83\u7b80\u5355\uff0c\u53ea\u8981\u5176\u4e2d<code>shift_size<\/code>\u8bbe\u7f6e\u4e3a0\u5c31\u662fW-MSA\u3002\u4e0b\u9762\u8ddf\u7740\u4ee3\u7801\u8d70\u4e00\u904d\u8fc7\u7a0b\u3002<\/p>\n\n\n\n<ul><li>\u8f93\u5165\uff1a&nbsp;x:B\u00d7562\u00d796&nbsp;\uff0c&nbsp;H,W=56<\/li><li>\u7ecf\u8fc7\u4e00\u5c42layerNorm<\/li><li>\u53d8\u5f62\uff1a&nbsp;x:B\u00d756\u00d756\u00d796<\/li><li>\u76f4\u63a5\u8d4b\u503c\u7ed9<code>shifted_x<\/code><\/li><li>\u8c03\u7528<code>window_partition<\/code>\u51fd\u6570\uff0c\u8f93\u5165<code>shifted_x<\/code>\uff0c<code>window_size=7<\/code>\uff1a<\/li><li>\u6ce8\u610f\u7a97\u53e3\u5927\u5c0f\u4ee5patch\u4e3a\u5355\u4f4d\uff0c\u6bd4\u59827\u5c31\u662f7\u4e2apatch\uff0c\u5982\u679c56\u7684\u5206\u8fa8\u7387\u5c31\u4f1a\u67098\u4e2a\u7a97\u53e3\u3002<\/li><li>\u8fd9\u4e2a\u51fd\u6570\u5bf9<code>shifted_x<\/code>\u505a\u4e00\u7cfb\u5217\u53d8\u5f62\uff0c\u6700\u7ec8\u53d8\u6210&nbsp;82B\u00d77\u00d77\u00d796<\/li><li>\u8fd4\u56de\u8d4b\u503c\u7ed9<code>x_windows<\/code>\uff0c\u518d\u53d8\u5f62\u6210&nbsp;82B\u00d772\u00d796&nbsp;\uff0c\u8fd9\u8868\u793a\u6240\u6709\u56fe\u7247\uff0c\u6bcf\u4e2a\u56fe\u7247\u768464\u4e2awindow\uff0c\u6bcf\u4e2awindow\u5185\u670949\u4e2apatch\u3002<\/li><li>\u8c03\u7528<code>WindowAttention<\/code>\u5c42\uff0c\u8fd9\u91cc\u4ee5\u5b83\u7684<code>num_head<\/code>\u4e3a3\u4e3a\u4f8b\u3002\u8f93\u5165\u53c2\u6570\u4e3a<code>x_windows<\/code>\u548c<code>self.attn_mask<\/code>\uff0c\u5bf9\u4e8eW-MSA\uff0c<code>attn_mask<\/code>\u4e3aNone\uff0c\u53ef\u4ee5\u4e0d\u7528\u7ba1\u3002<\/li><\/ul>\n\n\n\n<p><\/p>\n\n\n\n<h4><code>WindowAttention<\/code>\u4ee3\u7801\u5982\u4e0b\uff1a<\/h4>\n\n\n\n<p>\u4ee3\u7801\u4e2d\u4f7f\u75287&#215;7\u7684windowsize\uff0c\u5c06feature 
map\u5206\u5272\u4e3a\u4e0d\u540c\u7684window\uff0c\u5728\u6bcf\u4e2awindow\u4e2d\u8ba1\u7b97\u81ea\u6ce8\u610f\u529b\u3002<\/p>\n\n\n\n<p>Self-attention\u7684\u8ba1\u7b97\u516c\u5f0f\uff08B\u4e3a\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\uff09<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"139\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-6-1024x139.png\" alt=\"\" class=\"wp-image-8761\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-6-1024x139.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-6-300x41.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-6-768x104.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-6.png 1177w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>\u7edd\u5bf9\u4f4d\u7f6e\u7f16\u7801\u662f\u5728\u8fdb\u884cself-attention\u8ba1\u7b97\u4e4b\u524d\u4e3a\u6bcf\u4e00\u4e2atoken\u6dfb\u52a0\u4e00\u4e2a\u53ef\u5b66\u4e60\u7684\u53c2\u6570\uff0c\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\u5982\u4e0a\u5f0f\u6240\u793a\uff0c\u662f\u5728\u8fdb\u884cself-attention\u8ba1\u7b97\u65f6\uff0c\u5728\u8ba1\u7b97\u8fc7\u7a0b\u4e2d\u6dfb\u52a0\u4e00\u4e2a\u53ef\u5b66\u4e60\u7684\u76f8\u5bf9\u4f4d\u7f6e\u53c2\u6570\u3002<\/p>\n\n\n\n<p>\u5047\u8bbewindow_size = 
2*2\u5373\u6bcf\u4e2a\u7a97\u53e3\u67094\u4e2atoken\u00a0(M=2)\u00a0\uff0c\u5982\u56fe1\u6240\u793a\uff0c\u5728\u8ba1\u7b97self-attention\u65f6\uff0c\u6bcf\u4e2atoken\u90fd\u8981\u4e0e\u6240\u6709\u7684token\u8ba1\u7b97QK\u503c\uff0c\u5982\u56fe6\u6240\u793a\uff0c\u5f53\u4f4d\u7f6e1\u7684token\u8ba1\u7b97self-attention\u65f6\uff0c\u8981\u8ba1\u7b97\u4f4d\u7f6e1\u4e0e\u4f4d\u7f6e(1,2,3,4)\u7684QK\u503c\uff0c\u5373\u4ee5\u4f4d\u7f6e1\u7684token\u4e3a\u4e2d\u5fc3\u70b9\uff0c\u4e2d\u5fc3\u70b9\u4f4d\u7f6e\u5750\u6807(0,0)\uff0c\u5176\u4ed6\u4f4d\u7f6e\u8ba1\u7b97\u4e0e\u5f53\u524d\u4f4d\u7f6e\u5750\u6807\u7684\u504f\u79fb\u91cf\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-full is-resized\"><img loading=\"lazy\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-7.png\" alt=\"\" class=\"wp-image-8796\" width=\"678\" height=\"279\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-7.png 714w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-7-300x124.png 300w\" sizes=\"(max-width: 678px) 100vw, 678px\" \/><figcaption>\u5750\u6807\u53d8\u6362<\/figcaption><\/figure>\n\n\n\n<figure class=\"wp-block-image size-full is-resized\"><img loading=\"lazy\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-8.png\" alt=\"\" class=\"wp-image-8798\" width=\"651\" height=\"249\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-8.png 643w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-8-300x115.png 300w\" sizes=\"(max-width: 651px) 100vw, 651px\" \/><figcaption>\u5750\u6807\u53d8\u6362<\/figcaption><\/figure>\n\n\n\n<figure class=\"wp-block-image size-full is-resized\"><img loading=\"lazy\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-9.png\" alt=\"\" class=\"wp-image-8801\" width=\"653\" height=\"246\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-9.png 711w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-9-300x113.png 300w\" 
sizes=\"(max-width: 653px) 100vw, 653px\" \/><\/figure>\n\n\n\n<figure class=\"wp-block-image is-resized\"><img loading=\"lazy\" src=\"https:\/\/pic1.zhimg.com\/v2-bcda7ffdbf89cdc5aa90f6ba07e4c044_r.jpg\" alt=\"\" width=\"690\" height=\"603\"\/><figcaption>\u76f8\u5bf9\u4f4d\u7f6e\u7d22\u5f15\u6c42\u89e3\u6d41\u7a0b\u56fe<\/figcaption><\/figure>\n\n\n\n<p>\u6700\u540e\u751f\u6210\u7684\u662f\u76f8\u5bf9\u4f4d\u7f6e\u7d22\u5f15,relative_position_index.shape =\u00a0(M2\uff0cM2)\u00a0\uff0c\u5728\u7f51\u7edc\u4e2d\u6ce8\u518c\u6210\u4e3a\u4e00\u4e2a\u4e0d\u53ef\u5b66\u4e60\u7684\u53d8\u91cf\uff0crelative_position_index\u7684\u4f5c\u7528\u5c31\u662f\u6839\u636e\u6700\u7ec8\u7684\u7d22\u5f15\u503c\u627e\u5230\u5bf9\u5e94\u7684\u53ef\u5b66\u4e60\u7684\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\u3002relative_position_index\u7684\u6570\u503c\u8303\u56f4(0~8)\uff0c\u5373\u00a0(2M\u22121)\u2217(2M\u22121)\u00a0,\u6240\u4ee5\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\uff08relative position bias table\uff09\u53ef\u4ee5\u7531\u4e00\u4e2a3*3\u7684\u77e9\u9635\u8868\u793a\uff0c\u5982\u56fe7\u6240\u793a\uff1a\u8fd9\u6837\u5c31\u6839\u636eindex\u5bf9\u5e94\u4f4d\u7f6e\u7684\u7d22\u5f15\u627e\u5230table\u5bf9\u5e94\u4f4d\u7f6e\u7684\u503c\u4f5c\u4e3a\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\u3002<\/p>\n\n\n\n<div class=\"wp-block-image\"><figure class=\"aligncenter is-resized\"><img loading=\"lazy\" src=\"https:\/\/pic4.zhimg.com\/v2-a9d97c2d2a3e76beff0f83acc5e286e7_r.jpg\" alt=\"\" width=\"292\" height=\"251\"\/><figcaption>\u56fe7 \u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801<\/figcaption><\/figure><\/div>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"480\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-11-1024x480.png\" alt=\"\" class=\"wp-image-8806\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-11-1024x480.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-11-300x141.png 300w, 
http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-11-768x360.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-11.png 1027w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<p>\u56fe7\u4e2d\u76840-8\u4e3a\u7d22\u5f15\u503c\uff0c\u6bcf\u4e2a\u7d22\u5f15\u503c\u90fd\u5bf9\u5e94\u4e86\u00a0M2\u00a0\u7ef4\u53ef\u5b66\u4e60\u6570\u636e<strong>(\u6bcf\u4e2atoken\u90fd\u8981\u8ba1\u7b97\u00a0M2\u00a0\u4e2aQK\u503c\uff0c\u6bcf\u4e2aQK\u503c\u90fd\u8981\u52a0\u4e0a\u5bf9\u5e94\u7684\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801)<\/strong><\/p>\n\n\n\n<p>\u7ee7\u7eed\u4ee5\u56fe6\u4e2d&nbsp;M=2&nbsp;\u7684\u7a97\u53e3\u4e3a\u4f8b\uff0c\u5f53\u8ba1\u7b97\u4f4d\u7f6e1\u5bf9\u5e94\u7684&nbsp;M2&nbsp;\u4e2aQK\u503c\u65f6\uff0c\u5e94\u7528\u7684relative_position_index = [ 4, 5, 7, 8]&nbsp;(M2)\u4e2a \uff0c\u5bf9\u5e94\u7684\u6570\u636e\u5c31\u662f\u56fe7\u4e2d\u4f4d\u7f6e\u7d22\u5f154,5,7,8\u4f4d\u7f6e\u5bf9\u5e94\u7684&nbsp;M2&nbsp;\u7ef4\u6570\u636e\uff0c\u5373relative_position.shape =&nbsp;(M2\u2217M2)<\/p>\n\n\n\n<p>\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\u5728\u6e90\u7801WindowAttention\u4e2d\u5e94\u7528\uff0c\u4e86\u89e3\u539f\u7406\u4e4b\u540e\u5c31\u5f88\u5bb9\u6613\u80fd\u591f\u8bfb\u61c2\u7a0b\u5e8f\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class WindowAttention(nn.Module):\n    r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n    It supports both of shifted and non-shifted window.\n\n    Args:\n        dim (int): Number of input channels.\n        window_size (tuple&#91;int]): The height and width of the window.\n        num_heads (int): Number of attention heads.\n        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n        attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0\n        proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n    \"\"\"\n\n    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n\n        super().__init__()\n        self.dim = dim <em># \u8f93\u5165\u901a\u9053\u7684\u6570\u91cf<\/em>\n        self.window_size = window_size  <em># Wh, Ww<\/em>\n        self.num_heads = num_heads\n        head_dim = dim \/\/ num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n\n        <em># define a parameter table of relative position bias<\/em>\n        self.relative_position_bias_table = nn.Parameter(\n            torch.zeros((2 * window_size&#91;0] - 1) * (2 * window_size&#91;1] - 1), num_heads))  <em># 2*Wh-1 * 2*Ww-1, nH<\/em>  \u521d\u59cb\u5316\u8868\n\n        <em># get pair-wise relative position index for each token inside the window<\/em>\n        coords_h = torch.arange(self.window_size&#91;0]) <em># coords_h = tensor(&#91;0,1,2,...,self.window_size&#91;0]-1])  \u7ef4\u5ea6=Wh<\/em>\n        coords_w = torch.arange(self.window_size&#91;1]) <em># coords_w = tensor(&#91;0,1,2,...,self.window_size&#91;1]-1])  \u7ef4\u5ea6=Ww<\/em>\n\n        coords = torch.stack(torch.meshgrid(&#91;coords_h, coords_w]))  <em># 2, Wh, Ww<\/em>\n        coords_flatten = torch.flatten(coords, 1)  <em># 2, Wh*Ww<\/em>\n\n\n        relative_coords = coords_flatten&#91;:, :, None] - coords_flatten&#91;:, None, :]  <em># 2, Wh*Ww, Wh*Ww<\/em>\n        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  <em># Wh*Ww, Wh*Ww, 2<\/em>\n        relative_coords&#91;:, :, 0] += self.window_size&#91;0] - 1  <em># shift to start from 0<\/em>\n        relative_coords&#91;:, :, 1] += self.window_size&#91;1] - 1\n\n        '''\n        
\u540e\u9762\u6211\u4eec\u9700\u8981\u5c06\u5176\u5c55\u5f00\u6210\u4e00\u7ef4\u504f\u79fb\u91cf\u3002\u800c\u5bf9\u4e8e(2,1)\u548c(1,2)\u8fd9\u4e24\u4e2a\u5750\u6807\uff0c\u5728\u4e8c\u7ef4\u4e0a\u662f\u4e0d\u540c\u7684\uff0c\u4f46\u662f\u901a\u8fc7\u5c06x\\y\u5750\u6807\u76f8\u52a0\u8f6c\u6362\u4e3a\u4e00\u7ef4\u504f\u79fb\u7684\u65f6\u5019\n        \u4ed6\u4eec\u7684\u504f\u79fb\u91cf\u662f\u76f8\u7b49\u7684\uff0c\u6240\u4ee5\u9700\u8981\u5bf9\u5176\u505a\u4e58\u6cd5\u64cd\u4f5c\uff0c\u8fdb\u884c\u533a\u5206\n        '''\n\n        relative_coords&#91;:, :, 0] *= 2 * self.window_size&#91;1] - 1\n        <em># \u8ba1\u7b97\u5f97\u5230\u76f8\u5bf9\u4f4d\u7f6e\u7d22\u5f15<\/em>\n        <em># relative_position_index.shape = (M2, M2) \u610f\u601d\u662f\u4e00\u5171\u6709\u8fd9\u4e48\u591a\u4e2a\u4f4d\u7f6e<\/em>\n        relative_position_index = relative_coords.sum(-1)  <em># Wh*Ww, Wh*Ww <\/em>\n\n        '''\n        relative_position_index\u6ce8\u518c\u4e3a\u4e00\u4e2a\u4e0d\u53c2\u4e0e\u7f51\u7edc\u5b66\u4e60\u7684\u53d8\u91cf\n        '''\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n        '''\n        \u4f7f\u7528\u4ece\u622a\u65ad\u6b63\u6001\u5206\u5e03\u4e2d\u63d0\u53d6\u7684\u503c\u586b\u5145\u8f93\u5165\u5f20\u91cf\n        self.relative_position_bias_table \u662f\u51680\u5f20\u91cf\uff0c\u901a\u8fc7trunc_normal_ \u8fdb\u884c\u6570\u503c\u586b\u5145\n        '''\n        trunc_normal_(self.relative_position_bias_table, std=.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask=None):\n        \"\"\"\n        Args:\n            x: input features with shape of (num_windows*B, N, C)\n            N: number of all patches in the window\n            C: 
\u8f93\u5165\u901a\u8fc7\u7ebf\u6027\u5c42\u8f6c\u5316\u5f97\u5230\u7684\u7ef4\u5ea6C\n            mask: (0\/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n        \"\"\"\n        B_, N, C = x.shape\n        '''\n        x.shape = (num_windows*B, N, C)\n        self.qkv(x).shape = (num_windows*B, N, 3C)\n        self.qkv(x).reshape(B_, N, 3, self.num_heads, C \/\/ self.num_heads).shape = (num_windows*B, N, 3, num_heads, C\/\/num_heads)\n        self.qkv(x).reshape(B_, N, 3, self.num_heads, C \/\/ self.num_heads).permute(2, 0, 3, 1, 4).shape = (3, num_windows*B, num_heads, N, C\/\/num_heads)\n        '''\n        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C \/\/ self.num_heads).permute(2, 0, 3, 1, 4)\n        '''\n        q.shape = k.shape = v.shape = (num_windows*B, num_heads, N, C\/\/num_heads)\n        N = M2 \u4ee3\u8868patches\u7684\u6570\u91cf\n        C\/\/num_heads\u4ee3\u8868Q,K,V\u7684\u7ef4\u6570\n        '''\n        q, k, v = qkv&#91;0], qkv&#91;1], qkv&#91;2]  <em># make torchscript happy (cannot use tensor as tuple)<\/em>\n\n        <em># q\u4e58\u4e0a\u4e00\u4e2a\u653e\u7f29\u7cfb\u6570\uff0c\u5bf9\u5e94\u516c\u5f0f\u4e2d\u7684sqrt(d)<\/em>\n        q = q * self.scale\n\n        <em># attn.shape = (num_windows*B, num_heads, N, N)  N = M2 \u4ee3\u8868patches\u7684\u6570\u91cf<\/em>\n        attn = (q @ k.transpose(-2, -1))\n\n        '''\n        self.relative_position_bias_table.shape = (2*Wh-1 * 2*Ww-1, nH)\n        self.relative_position_index.shape = (Wh*Ww, Wh*Ww)\n        self.relative_position_index\u77e9\u9635\u4e2d\u7684\u6240\u6709\u503c\u90fd\u662f\u4eceself.relative_position_bias_table\u4e2d\u53d6\u7684\n        self.relative_position_index\u662f\u8ba1\u7b97\u51fa\u6765\u4e0d\u53ef\u5b66\u4e60\u7684\u91cf\n        '''\n        relative_position_bias = self.relative_position_bias_table&#91;self.relative_position_index.view(-1)].view(\n            self.window_size&#91;0] * self.window_size&#91;1], 
self.window_size&#91;0] * self.window_size&#91;1], -1)  <em># Wh*Ww,Wh*Ww,nH<\/em>\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  <em># nH, Wh*Ww, Wh*Ww<\/em>\n\n        '''\n        attn.shape = (num_windows*B, num_heads, M2, M2)  N = M2 \u4ee3\u8868patches\u7684\u6570\u91cf\n        .unsqueeze(0)\uff1a\u6269\u5f20\u7ef4\u5ea6\uff0c\u57280\u5bf9\u5e94\u7684\u4f4d\u7f6e\u63d2\u5165\u7ef4\u5ea61\n        relative_position_bias.unsqueeze(0).shape = (1, num_heads, M2, M2)\n        num_windows*B \u901a\u8fc7\u5e7f\u64ad\u673a\u5236\u4f20\u64ad\uff0crelative_position_bias.unsqueeze(0).shape = (1, nH, M2, M2) \u7684\u7ef4\u5ea61\u4f1abroadcast\u5230\u6570\u91cfnum_windows*B\n        \u8868\u793a\u6240\u6709batch\u901a\u7528\u4e00\u4e2a\u7d22\u5f15\u77e9\u9635\u548c\u76f8\u5bf9\u4f4d\u7f6e\u77e9\u9635\n        '''\n        attn = attn + relative_position_bias.unsqueeze(0)\n\n        <em># mask.shape = (num_windows, M2, M2)<\/em>\n        <em># attn.shape = (num_windows*B, num_heads, M2, M2)<\/em>\n        if mask is not None:\n            nW = mask.shape&#91;0]\n            <em># attn.view(B_ \/\/ nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) \u7b2c\u4e00\u4e2aM2\u4ee3\u8868\u6709M2\u4e2atoken\uff0c\u7b2c\u4e8c\u4e2aM2\u4ee3\u8868\u6bcf\u4e2atoken\u8981\u8ba1\u7b97M2\u6b21QKT\u7684\u503c<\/em>\n            <em># mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) \u7b2c\u4e00\u4e2aM2\u4ee3\u8868\u6709M2\u4e2atoken\uff0c\u7b2c\u4e8c\u4e2aM2\u4ee3\u8868\u6bcf\u4e2atoken\u8981\u8ba1\u7b97M2\u6b21QKT\u7684\u503c<\/em>\n            <em># broadcast\u76f8\u52a0<\/em>\n            attn = attn.view(B_ \/\/ nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            <em># attn.shape = (B, num_windows, num_heads, M2, M2)<\/em>\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = 
self.softmax(attn)\n\n        attn = self.attn_drop(attn)\n\n        '''\n        v.shape = (num_windows*B, num_heads, M2, C\/\/num_heads)  N=M2 \u4ee3\u8868patches\u7684\u6570\u91cf, C\/\/num_heads\u4ee3\u8868\u8f93\u5165\u7684\u7ef4\u5ea6\n        attn.shape = (num_windows*B, num_heads, M2, M2)\n        attn@v .shape = (num_windows*B, num_heads, M2, C\/\/num_heads)\n        '''\n        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)   <em># B_:num_windows*B  N:M2  C=num_heads*C\/\/num_heads<\/em>\n\n        <em>#   self.proj = nn.Linear(dim, dim)  dim = C<\/em>\n        <em>#   self.proj_drop = nn.Dropout(proj_drop)<\/em>\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x  <em># x.shape = (num_windows*B, N, C)  N:\u7a97\u53e3\u4e2d\u6240\u6709patches\u7684\u6570\u91cf<\/em>\n\n    def extra_repr(self) -> str:\n        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n\n    def flops(self, N):\n        <em># calculate flops for 1 window with token length of N<\/em>\n        flops = 0\n        <em># qkv = self.qkv(x)<\/em>\n        flops += N * self.dim * 3 * self.dim\n        <em># attn = (q @ k.transpose(-2, -1))<\/em>\n        flops += self.num_heads * N * (self.dim \/\/ self.num_heads) * N\n        <em>#  x = (attn @ v)<\/em>\n        flops += self.num_heads * N * N * (self.dim \/\/ self.num_heads)\n        <em># x = self.proj(x)<\/em>\n        flops += N * self.dim * self.dim\n        return flops<\/code><\/pre>\n\n\n\n<p>\u5728\u4e0a\u8ff0\u7a0b\u5e8f\u4e2d\u6709\u4e00\u6bb5mask\u76f8\u5173\u7a0b\u5e8f\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>if mask is not None:\n            nW = mask.shape&#91;0]\n            <em># attn.view(B_ \/\/ nW, nW, self.num_heads, N, N).shape = (B, num_windows, num_heads, M2, M2) \u7b2c\u4e00\u4e2aM2\u4ee3\u8868\u6709M2\u4e2atoken\uff0c\u7b2c\u4e8c\u4e2aM2\u4ee3\u8868\u6bcf\u4e2atoken\u8981\u8ba1\u7b97M2\u6b21QKT\u7684\u503c<\/em>\n            <em># 
mask.unsqueeze(1).unsqueeze(0).shape =                (1, num_windows, 1,         M2, M2) \u7b2c\u4e00\u4e2aM2\u4ee3\u8868\u6709M2\u4e2atoken\uff0c\u7b2c\u4e8c\u4e2aM2\u4ee3\u8868\u6bcf\u4e2atoken\u8981\u8ba1\u7b97M2\u6b21QKT\u7684\u503c<\/em>\n            <em># broadcast\u76f8\u52a0<\/em>\n            attn = attn.view(B_ \/\/ nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n            <em># attn.shape = (B, num_windows, num_heads, M2, M2)<\/em>\n            attn = attn.view(-1, self.num_heads, N, N)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)<\/code><\/pre>\n\n\n\n<p>\u8fd9\u4e2a\u90e8\u5206\u5bf9\u5e94\u7684\u662fSwin Transformer Block \u4e2d\u7684SW-MSA<\/p>\n\n\n\n<ul><li>\u8f93\u5165&nbsp;x:64B\u00d749\u00d796&nbsp;\u3002<\/li><li>\u4ea7\u751f&nbsp;QKV&nbsp;\uff0c\u8c03\u7528\u7ebf\u6027\u5c42\u540e\uff0c\u5f97\u5230&nbsp;64B\u00d749\u00d7(96\u00d73)&nbsp;\uff0c\u62c6\u5206\u7ed9\u4e0d\u540c\u7684head\uff0c\u5f97\u5230&nbsp;64B\u00d749\u00d73\u00d73\u00d732&nbsp;\uff0c\u7b2c\u4e00\u4e2a3\u662f&nbsp;QKV&nbsp;\u76843\uff0c\u7b2c\u4e8c\u4e2a3\u662f3\u4e2ahead\u3002\u518dpermute\u6210&nbsp;3\u00d764B\u00d73\u00d749\u00d732&nbsp;\uff0c\u518d\u62c6\u89e3\u6210&nbsp;q,k,v&nbsp;\uff0c\u6bcf\u4e2a\u90fd\u662f&nbsp;64B\u00d73\u00d749\u00d732&nbsp;\u3002\u8868\u793a\u6240\u6709\u56fe\u7247\u7684\u6bcf\u4e2a\u56fe\u724764\u4e2awindow\uff0c\u6bcf\u4e2awindow\u5bf9\u5e94\u52303\u4e2a\u4e0d\u540c\u7684head\uff0c\u90fd\u6709\u4e00\u595749\u4e2apatch\u300132\u7ef4\u7684\u7279\u5f81\u3002<\/li><li>q&nbsp;\u5f52\u4e00\u5316<\/li><li>qk&nbsp;\u77e9\u9635\u76f8\u4e58\u6c42\u7279\u5f81\u5185\u79ef\uff0c\u5f97\u5230&nbsp;attn:64B\u00d73\u00d749\u00d749<\/li><li>\u5f97\u5230\u76f8\u5bf9\u4f4d\u7f6e\u7684\u7f16\u7801\u4fe1\u606f<code>relative_position_bias<\/code>\uff1a<ul><li>\u4ee3\u7801\u5982\u4e0b\uff1a<\/li><\/ul><\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>self.relative_position_bias_table = 
nn.Parameter(\n            torch.zeros((2 * window_size&#91;0] - 1) * (2 * window_size&#91;1] - 1), num_heads))  <em># 2*Wh-1 * 2*Ww-1, nH<\/em>\n\n<em># get pair-wise relative position index for each token inside the window<\/em>\ncoords_h = torch.arange(self.window_size&#91;0])\ncoords_w = torch.arange(self.window_size&#91;1])\ncoords = torch.stack(torch.meshgrid(&#91;coords_h, coords_w]))  <em># 2, Wh, Ww<\/em>\ncoords_flatten = torch.flatten(coords, 1)  <em># 2, Wh*Ww<\/em>\nrelative_coords = coords_flatten&#91;:, :, None] - coords_flatten&#91;:, None, :]  <em># 2, Wh*Ww, Wh*Ww<\/em>\nrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  <em># Wh*Ww, Wh*Ww, 2<\/em>\nrelative_coords&#91;:, :, 0] += self.window_size&#91;0] - 1  <em># shift to start from 0<\/em>\nrelative_coords&#91;:, :, 1] += self.window_size&#91;1] - 1\nrelative_coords&#91;:, :, 0] *= 2 * self.window_size&#91;1] - 1\nrelative_position_index = relative_coords.sum(-1)  <em># Wh*Ww, Wh*Ww<\/em>\nself.register_buffer(\"relative_position_index\", 
relative_position_index)<\/code><\/pre>\n\n\n\n<ul><li>\u8fd9\u91cc\u4ee5<code>window_size=3<\/code>\u4e3a\u4f8b\uff0c\u89e3\u91ca\u4ee5\u4e0b\u8fc7\u7a0b\uff1a\u9996\u5148\u751f\u6210&nbsp;coords:2\u00d73\u00d73&nbsp;\uff0c\u5c31\u662f\u5728\u4e00\u4e2a&nbsp;3\u00d73&nbsp;\u7684\u7a97\u53e3\u5185\uff0c\u6bcf\u4e2a\u4f4d\u7f6e\u7684&nbsp;y,x&nbsp;\u5750\u6807\uff0c\u800c<code>relative_coords<\/code>\u4e3a&nbsp;2\u00d79\u00d79&nbsp;\uff0c\u5c31\u662f9\u4e2a\u70b9\u4e2d\uff0c\u6bcf\u4e2a\u70b9\u7684&nbsp;y&nbsp;\u6216&nbsp;x&nbsp;\u4e0e\u5176\u4ed6\u6240\u6709\u70b9\u7684\u5dee\u503c\uff0c\u6bd4\u5982&nbsp;[0][3][1]&nbsp;\u8868\u793a3\u53f7\u70b9\uff08\u7b2c\u4e8c\u884c\u7b2c\u4e00\u4e2a\u70b9\uff09\u4e0e1\u53f7\u70b9\uff08\u7b2c\u4e00\u884c\u7b2c\u4e8c\u4e2a\u70b9\uff09\u7684&nbsp;y&nbsp;\u5750\u6807\u7684\u5dee\u503c\u3002\u7136\u540e\u53d8\u5f62\uff0c\u5e76\u8ba9\u4e24\u4e2a\u5750\u6807\u5206\u522b\u52a0\u4e0a&nbsp;3\u22121=2&nbsp;\uff0c\u662f\u56e0\u4e3a\u8fd9\u4e9b\u5750\u6807\u503c\u8303\u56f4&nbsp;[0,2]&nbsp;\uff0c\u56e0\u6b64\u5dee\u503c\u7684\u6700\u5c0f\u503c\u4e3a-2\uff0c\u52a0\u4e0a2\u540e\u4ece0\u5f00\u59cb\u3002\u6700\u540e\u8ba9&nbsp;y&nbsp;\u5750\u6807\u4e58\u4e0a&nbsp;2\u00d73\u22121=5&nbsp;\uff0c\u5e94\u8be5\u662f\u4e00\u4e2atrick\uff0c\u8c03\u6574\u5dee\u503c\u8303\u56f4\u3002\u6700\u540e\u5c06\u4e24\u4e2a\u7ef4\u5ea6\u7684\u5dee\u503c\u76f8\u52a0\uff0c\u5f97\u5230<code>relative_position_index<\/code>\uff0c&nbsp;9\u00d79&nbsp;\uff0c\u4e3a9\u4e2a\u70b9\u4e4b\u95f4\u4e24\u4e24\u4e4b\u95f4\u7684\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\u503c\uff0c\u6700\u540e\u7528\u6765\u5230<code>self.relative_position_bias_table<\/code>\u4e2d\u5bfb\u5740\uff0c\u6ce8\u610f\u76f8\u5bf9\u4f4d\u7f6e\u7684\u6700\u5927\u503c\u4e3a&nbsp;(2M\u22122)(2M\u22121)&nbsp;\uff0c\u800c\u8fd9\u4e2atable\u6700\u591a\u6709&nbsp;(2M\u22121)(2M\u22121)&nbsp;\u884c\uff0c\u56e0\u6b64\u4fdd\u8bc1\u53ef\u4ee5\u5bfb\u5740\uff0c\u5f97\u5230\u4e86\u4e00\u7ec4\u7ed9\u591a\u4e2ahead\u4f7f\u7528\u76
84\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\u4fe1\u606f\uff0c\u8fd9\u4e2atable\u662f\u53ef\u8bad\u7ec3\u7684\u53c2\u6570\u3002<\/li><li>\u56de\u5230\u4ee3\u7801\u4e2d\uff0c\u5f97\u5230\u7684<code>relative_position_bias<\/code>\u4e3a&nbsp;3\u00d749\u00d749<\/li><li>\u5c06\u5176\u52a0\u5230<code>attn<\/code>\u4e0a\uff0c\u6700\u540e\u4e00\u4e2a\u7ef4\u5ea6softmax\uff0cdropout<\/li><li>\u4e0e&nbsp;v&nbsp;\u77e9\u9635\u76f8\u4e58\uff0c\u5e76\u8f6c\u7f6e\uff0c\u5408\u5e76\u591a\u4e2a\u5934\u7684\u4fe1\u606f\uff0c\u5f97\u5230&nbsp;64B\u00d749\u00d796<\/li><li>\u7ecf\u8fc7\u4e00\u5c42\u7ebf\u6027\u5c42\uff0cdropout\uff0c\u8fd4\u56de<\/li><li>\u8fd4\u56de\u8d4b\u503c\u7ed9<code>attn_windows<\/code>\uff0c\u53d8\u5f62\u4e3a&nbsp;64B\u00d77\u00d77\u00d796<\/li><li>\u8c03\u7528<code>window_reverse<\/code>\uff0c\u6253\u56de\u539f\u72b6\uff1a&nbsp;B\u00d756\u00d756\u00d796<\/li><li>\u8fd4\u56de\u7ed9&nbsp;x&nbsp;\uff0c\u7ecf\u8fc7FFN\uff1a\u5148\u52a0\u4e0a\u539f\u6765\u7684\u8f93\u5165&nbsp;x&nbsp;\u4f5c\u4e3aresidue\u7ed3\u6784\uff0c\u6ce8\u610f\u8fd9\u91cc\u7528\u5230<a href=\"https:\/\/github.com\/rwightman\/pytorch-image-models\/blob\/master\/timm\/models\/layers\/drop.py\" target=\"_blank\" rel=\"noreferrer noopener\">timm<\/a>\u7684<code>DropPath<\/code>\uff0c\u5e76\u4e14drop\u7684\u6982\u7387\u662f\u6574\u4e2a\u7f51\u7edc\u7ed3\u6784\u7ebf\u6027\u589e\u957f\u7684\u3002\u7136\u540e\u518d\u52a0\u4e0a\u4e24\u5c42mlp\u7684\u7ed3\u679c\u3002<\/li><li>\u8fd4\u56de\u7ed3\u679c&nbsp;x&nbsp;\u3002<\/li><\/ul>\n\n\n\n<p>\u8fd9\u6837\uff0c\u6574\u4e2a\u8fc7\u7a0b\u5c31\u5b8c\u6210\u4e86\uff0c\u5269\u4e0b\u7684\u5c31\u662fSW-MSA\u7684\u4e00\u4e9b\u4e0d\u540c\u7684\u64cd\u4f5c\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" width=\"598\" height=\"215\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-3.png\" alt=\"\" class=\"wp-image-8734\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-3.png 598w, 
http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-3-300x108.png 300w\" sizes=\"(max-width: 598px) 100vw, 598px\" \/><\/figure>\n\n\n\n<ol><li>\u9996\u5148\u5c06windows\u8fdb\u884c\u534a\u4e2a\u7a97\u53e3\u7684\u5faa\u73af\u79fb\u4f4d\uff0c\u4e0a\u56fe\u4e2d\u76841\uff0c 2\u6b65\u9aa4\uff0c\u4f7f\u7528torch.roll\u5b9e\u73b0\u3002<\/li><li>\u5728\u76f8\u540c\u7684\u7a97\u53e3\u4e2d\u8ba1\u7b97\u81ea\u6ce8\u610f\u529b\uff0c\u8ba1\u7b97\u7ed3\u679c\u5982\u4e0b\u53f3\u56fe\u6240\u793a\uff0cwindow0\u7684\u7ed3\u6784\u4fdd\u5b58\uff0c\u4f46\u662f\u9488\u5bf9window2\u7684\u8ba1\u7b97\uff0c\u5176\u4e2d3\u4e0e3\u30016\u4e0e6\u7684\u8ba1\u7b97\u751f\u6210\u4e86attn mask \u4e2dwindow2\u4e2d\u7684\u9ec4\u8272\u533a\u57df\uff0c\u9488\u5bf9windows2\u4e2d3\u4e0e6\u30016\u4e0e3\u4e4b\u95f4\u4e0d\u5e94\u8be5\u8ba1\u7b97\u81ea\u6ce8\u610f\u529b\uff08attn mask\u4e2dwindow2\u7684\u84dd\u8272\u533a\u57df\uff09\uff0c\u5c06\u84dd\u8272\u533a\u57dfmask\u8d4b\u503c\u4e3a-100\uff0c\u7ecf\u8fc7softmax\u4e4b\u540e\uff0c\u8d77\u4f5c\u7528\u53ef\u4ee5\u5ffd\u7565\u4e0d\u8ba1\u3002\u540c\u7406window1\u4e0ewindow3\u7684\u8ba1\u7b97\u4e00\u81f4\u3002<\/li><li>\u6700\u540e\u518d\u8fdb\u884c\u5faa\u73af\u79fb\u4f4d\uff0c\u6062\u590d\u539f\u6765\u7684\u4f4d\u7f6e\u3002<\/li><\/ol>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"451\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-4-1024x451.png\" alt=\"\" class=\"wp-image-8736\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-4-1024x451.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-4-300x132.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-4-768x338.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-4.png 1296w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" 
\/><\/figure>\n\n\n\n<p>\u539f\u8bba\u6587\u56fe\u4e2d\u7684Stage\u548c\u7a0b\u5e8f\u4e2d\u7684\u4e00\u4e2aStage\u4e0d\u540c\uff1a<\/p>\n\n\n\n<p>\u7a0b\u5e8f\u4e2d\u7684BasicLayer\u4e3a\u4e00\u4e2aStage\uff0c\u5728BasicLayer\u4e2d\u8c03\u7528\u4e86\u4e0a\u9762\u8bb2\u5230\u7684SwinTransformerBlock\u548cPatchMerging\u6a21\u5757:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class BasicLayer(nn.Module):  <em># \u8bba\u6587\u56fe\u4e2d\u6bcf\u4e2astage\u91cc\u5bf9\u5e94\u7684\u82e5\u5e72\u4e2aSwinTransformerBlock<\/em>\n    \"\"\" A basic Swin Transformer layer for one stage.\n\n    Args:\n        dim (int): Number of input channels.\n        input_resolution (tuple&#91;int]): Input resolution.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        window_size (int): Local window size.\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n        drop (float, optional): Dropout rate. Default: 0.0\n        attn_drop (float, optional): Attention dropout rate. Default: 0.0\n        drop_path (float | tuple&#91;float], optional): Stochastic depth rate. Default: 0.0\n        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n    \"\"\"\n\n    def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):\n\n        super().__init__()\n        self.dim = dim\n        self.input_resolution = input_resolution\n        self.depth = depth <em># swin_transformer blocks\u7684\u4e2a\u6570<\/em>\n        self.use_checkpoint = use_checkpoint\n\n        <em># build blocks  \u4ece0\u5f00\u59cb\u7684\u5076\u6570\u4f4d\u7f6e\u7684SwinTransformerBlock\u8ba1\u7b97\u7684\u662fW-MSA,\u5947\u6570\u4f4d\u7f6e\u7684Block\u8ba1\u7b97\u7684\u662fSW-MSA\uff0c\u4e14shift_size = window_size\/\/2<\/em>\n        self.blocks = nn.ModuleList(&#91;\n            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n                                 num_heads=num_heads, window_size=window_size,\n                                 shift_size=0 if (i % 2 == 0) else window_size \/\/ 2,\n                                 mlp_ratio=mlp_ratio,\n                                 qkv_bias=qkv_bias, qk_scale=qk_scale,\n                                 drop=drop, attn_drop=attn_drop,\n                                 drop_path=drop_path&#91;i] if isinstance(drop_path, list) else drop_path,\n                                 norm_layer=norm_layer)\n            for i in range(depth)])\n\n        <em># patch merging layer<\/em>\n        if downsample is not None:\n            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n        else:\n            self.downsample = None\n\n    def forward(self, x):\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint.checkpoint(blk, x)\n            else:\n                x = blk(x)  <em># blk = SwinTransformerBlock<\/em>\n        if self.downsample is not None:\n            x = self.downsample(x)\n        
return x\n\n    def extra_repr(self) -&gt; str:\n        return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n\n    def flops(self):\n        flops = 0\n        for blk in self.blocks:\n            flops += blk.flops()\n        if self.downsample is not None:\n            flops += self.downsample.flops()\n        return flops<\/code><\/pre>\n\n\n\n<p><strong><em>Part 3 : \u4e0d\u540c\u89c6\u89c9\u4efb\u52a1\u8f93\u51fa<\/em><\/strong><\/p>\n\n\n\n<p>\u7a0b\u5e8f\u4e2d\u5bf9\u5e94\u7684\u662f\u56fe\u7247\u5206\u7c7b\u4efb\u52a1\uff0c\u7ecf\u8fc7Part 2 \u4e4b\u540e\u7684\u6570\u636e\u901a\u8fc7 norm\/avgpool\/flatten:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code> x = self.norm(x)  <em># B L C<\/em>\n x = self.avgpool(x.transpose(1, 2))  <em># B C 1<\/em>\n x = torch.flatten(x, 1) <em># B C<\/em><\/code><\/pre>\n\n\n\n<p>\u4e4b\u540e\u901a\u8fc7nn.Linear\u5c06\u7279\u5f81\u8f6c\u5316\u4e3a\u5bf9\u5e94\u7684\u7c7b\u522b\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>self.head = nn.Linear(self.num_features, num_classes) if num_classes &gt; 0 else nn.Identity()<\/code><\/pre>\n\n\n\n<p>\u5e94\u7528\u4e8e\u5176\u4ed6\u4e0d\u540c\u7684\u89c6\u89c9\u4efb\u52a1\u65f6\uff0c\u53ea\u9700\u8981\u5c06\u8f93\u51fa\u8fdb\u884c\u7279\u5b9a\u7684\u4fee\u6539\u5373\u53ef\u3002<\/p>\n\n\n\n<h2><strong><em>\u5b8c\u6574\u7684SwinTransformer\u7a0b\u5e8f\u5982\u4e0b\uff1a<\/em><\/strong><\/h2>\n\n\n\n<pre class=\"wp-block-code\"><code>class SwinTransformer(nn.Module):\n    r\"\"\" Swin Transformer\n        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -\n          https:&#47;&#47;arxiv.org\/pdf\/2103.14030\n\n    Args:\n        img_size (int | tuple(int)): Input image size. Default 224\n        patch_size (int | tuple(int)): Patch size. Default: 4\n        in_chans (int): Number of input image channels. Default: 3\n        num_classes (int): Number of classes for classification head. 
Default: 1000\n        embed_dim (int): Patch embedding dimension. Default: 96\n        depths (tuple(int)): Depth of each Swin Transformer layer.\n        num_heads (tuple(int)): Number of attention heads in different layers.\n        window_size (int): Window size. Default: 7\n        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n        drop_rate (float): Dropout rate. Default: 0\n        attn_drop_rate (float): Attention dropout rate. Default: 0\n        drop_path_rate (float): Stochastic depth rate. Default: 0.1\n        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n        patch_norm (bool): If True, add normalization after patch embedding. Default: True\n        use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,\n                 embed_dim=96, depths=&#91;2, 2, 6, 2], num_heads=&#91;3, 6, 12, 24],\n                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n                 use_checkpoint=False, **kwargs):\n        super().__init__()\n\n        self.num_classes = num_classes <em># 1000<\/em>\n        self.num_layers = len(depths) <em># &#91;2, 2, 6, 2]  Swin_T \u7684\u914d\u7f6e<\/em>\n        self.embed_dim = embed_dim <em># 96<\/em>\n        self.ape = ape <em># False<\/em>\n        self.patch_norm = patch_norm <em># True<\/em>\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))  <em># 96*2^3<\/em>\n        self.mlp_ratio = mlp_ratio <em># 4<\/em>\n\n        <em># split image into non-overlapping patches<\/em>\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None)\n        num_patches = self.patch_embed.num_patches\n        patches_resolution = self.patch_embed.patches_resolution\n        self.patches_resolution = patches_resolution\n\n        <em># absolute position embedding<\/em>\n        if self.ape:\n            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))\n            trunc_normal_(self.absolute_pos_embed, std=.02)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        <em># stochastic depth<\/em>\n        dpr = &#91;x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  <em># stochastic depth decay rule<\/em>\n\n        <em># build layers<\/em>\n        self.layers = nn.ModuleList()\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(dim=int(embed_dim * 2 
** i_layer),\n                               input_resolution=(patches_resolution&#91;0] \/\/ (2 ** i_layer),\n                                                 patches_resolution&#91;1] \/\/ (2 ** i_layer)),\n                               depth=depths&#91;i_layer],\n                               num_heads=num_heads&#91;i_layer],\n                               window_size=window_size,\n                               mlp_ratio=self.mlp_ratio,\n                               qkv_bias=qkv_bias, qk_scale=qk_scale,\n                               drop=drop_rate, attn_drop=attn_drop_rate,\n                               drop_path=dpr&#91;sum(depths&#91;:i_layer]):sum(depths&#91;:i_layer + 1])],\n                               norm_layer=norm_layer,\n                               downsample=PatchMerging if (i_layer &lt; self.num_layers - 1) else None,\n                               use_checkpoint=use_checkpoint)\n            self.layers.append(layer)\n\n        self.norm = norm_layer(self.num_features) <em># norm_layer = nn.LayerNorm<\/em>\n        self.avgpool = nn.AdaptiveAvgPool1d(1)\n        self.head = nn.Linear(self.num_features, num_classes) if num_classes &gt; 0 else nn.Identity()\n\n        self.apply(self._init_weights)  <em># \u4f7f\u7528self.apply \u521d\u59cb\u5316\u53c2\u6570<\/em>\n\n    def _init_weights(self, m):\n        <em># is_instance \u5224\u65ad\u5bf9\u8c61\u662f\u5426\u4e3a\u5df2\u77e5\u7c7b\u578b<\/em>\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'absolute_pos_embed'}\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'relative_position_bias_table'}\n\n    def 
forward_features(self, x):\n        x = self.patch_embed(x)  <em># x.shape = (H\/\/4, W\/\/4, C)<\/em>\n        if self.ape:\n            x = x + self.absolute_pos_embed\n        x = self.pos_drop(x)  <em># self.pos_drop = nn.Dropout(p=drop_rate)<\/em>\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm(x)  <em># B L C<\/em>\n        x = self.avgpool(x.transpose(1, 2))  <em># B C 1<\/em>\n        x = torch.flatten(x, 1) <em># B C<\/em>\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)  <em># x\u662f\u8bba\u6587\u56fe\u4e2dFigure 3 a\u56fe\u4e2d\u6700\u540e\u7684\u8f93\u51fa<\/em>\n        <em>#  self.head = nn.Linear(self.num_features, num_classes) if num_classes &gt; 0 else nn.Identity()<\/em>\n        x = self.head(x) <em># x.shape = (B, num_classes)<\/em>\n        return x\n\n    def flops(self):\n        flops = 0\n        flops += self.patch_embed.flops()\n        for i, layer in enumerate(self.layers):\n            flops += layer.flops()\n        flops += self.num_features * self.patches_resolution&#91;0] * self.patches_resolution&#91;1] \/\/ (2 ** self.num_layers)\n        flops += self.num_features * self.num_classes\n        return flops<\/code><\/pre>\n\n\n\n<p class=\"has-bright-blue-color has-light-pink-background-color has-text-color has-background\">\u8865\u5145\uff1a\u6709\u5173swin\u00a0transformer\u76f8\u5bf9\u4f4d\u7f6e\u7f16\u7801\uff1a<\/p>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"392\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-12-1024x392.png\" alt=\"\" class=\"wp-image-8818\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-12-1024x392.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-12-300x115.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-12-768x294.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-12.png 1071w\" 
sizes=\"(max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" width=\"993\" height=\"117\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-13.png\" alt=\"\" class=\"wp-image-8819\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-13.png 993w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-13-300x35.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-13-768x90.png 768w\" sizes=\"(max-width: 993px) 100vw, 993px\" \/><\/figure>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"235\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-14-1024x235.png\" alt=\"\" class=\"wp-image-8821\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-14-1024x235.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-14-300x69.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-14-768x176.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-14.png 1148w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n\n\n\n<figure class=\"wp-block-image size-large\"><img loading=\"lazy\" width=\"1024\" height=\"411\" src=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-15-1024x411.png\" alt=\"\" class=\"wp-image-8822\" srcset=\"http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-15-1024x411.png 1024w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-15-300x120.png 300w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-15-768x308.png 768w, http:\/\/139.9.1.231\/wp-content\/uploads\/2022\/10\/image-15.png 1094w\" sizes=\"(max-width: 1024px) 100vw, 1024px\" \/><\/figure>\n","protected":false},"excerpt":{"rendered":"<p>code\uff1ahttps:\/\/github.com\/microsoft\/Swin-Transformer \u4ee3\u7801\u8be6\u89e3 &hellip; <a href=\"http:\/\/139.9.1.231\/index.php\/2022\/10\/04\/swin-transformer-code\/\" 
class=\"more-link\">\u7ee7\u7eed\u9605\u8bfb<span class=\"screen-reader-text\">Swin Transformer \u4ee3\u7801\u8be6\u89e3<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":[],"categories":[21,4],"tags":[],"_links":{"self":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/8652"}],"collection":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/comments?post=8652"}],"version-history":[{"count":59,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/8652\/revisions"}],"predecessor-version":[{"id":8823,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/8652\/revisions\/8823"}],"wp:attachment":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/media?parent=8652"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/categories?post=8652"},{"taxonomy":"post_tag","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/tags?post=8652"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}