{"id":5780,"date":"2022-08-13T21:15:22","date_gmt":"2022-08-13T13:15:22","guid":{"rendered":"http:\/\/139.9.1.231\/?p=5780"},"modified":"2022-08-13T21:15:24","modified_gmt":"2022-08-13T13:15:24","slug":"huggingface-transformers-bert","status":"publish","type":"post","link":"http:\/\/139.9.1.231\/index.php\/2022\/08\/13\/huggingface-transformers-bert\/","title":{"rendered":"HuggingFace Transformers &#8212;-BERT \u6e90\u7801"},"content":{"rendered":"\n<p>\u6458\u81ea\uff1a https:\/\/zhuanlan.zhihu.com\/p\/360988428<\/p>\n\n\n\n<p>\u4f17\u6240\u5468\u77e5\uff0cBERT\u6a21\u578b\u81ea2018\u5e74\u95ee\u4e16\u8d77\u5c31\u5404\u79cd\u5c60\u699c\uff0c\u5f00\u542f\u4e86NLP\u9886\u57df\u9884\u8bad\u7ec3+\u5fae\u8c03\u7684\u8303\u5f0f\u3002\u5230\u73b0\u5728\uff0cBERT\u7684\u76f8\u5173\u884d\u751f\u6a21\u578b\u5c42\u51fa\u4e0d\u7a77\uff08XL-Net\u3001RoBERTa\u3001ALBERT\u3001ELECTRA\u3001ERNIE\u7b49\uff09\uff0c\u8981\u7406\u89e3\u5b83\u4eec\u53ef\u4ee5\u5148\u4eceBERT\u8fd9\u4e2a\u59cb\u7956\u5165\u624b\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img src=\"https:\/\/pic4.zhimg.com\/v2-4d162cfcf61a5621ab677fb4d06fe7f3_r.jpg\" alt=\"\"\/><\/figure>\n\n\n\n<p>HuggingFace\u662f\u4e00\u5bb6\u603b\u90e8\u4f4d\u4e8e\u7ebd\u7ea6\u7684\u804a\u5929\u673a\u5668\u4eba\u521d\u521b\u670d\u52a1\u5546\uff0c\u5f88\u65e9\u5c31\u6355\u6349\u5230BERT\u5927\u6f6e\u6d41\u7684\u4fe1\u53f7\u5e76\u7740\u624b\u5b9e\u73b0\u57fa\u4e8epytorch\u7684BERT\u6a21\u578b\u3002\u8fd9\u4e00\u9879\u76ee\u6700\u521d\u540d\u4e3apytorch-pretrained-bert\uff0c\u5728\u590d\u73b0\u4e86\u539f\u59cb\u6548\u679c\u7684\u540c\u65f6\uff0c\u63d0\u4f9b\u4e86\u6613\u7528\u7684\u65b9\u6cd5\u4ee5\u65b9\u4fbf\u5728\u8fd9\u4e00\u5f3a\u5927\u6a21\u578b\u7684\u57fa\u7840\u4e0a\u8fdb\u884c\u5404\u79cd\u73a9\u800d\u548c\u7814\u7a76\u3002<\/p>\n\n\n\n<p>\u968f\u7740\u4f7f\u7528\u4eba\u6570\u7684\u589e\u52a0\uff0c\u8fd9\u4e00\u9879\u76ee\u4e5f\u53d1\u5c55\u6210\u4e3a\u4e00\u4e2a\u8f83\u5927\u7684\u5f00\u6e90\u793e\u533a\uff0c\u5408\u5e76\u4e86\u5404\u79cd\u9884\u8bad\u7ec3\u8bed\u8a00\u6a21\u578b\u4ee5\u53ca\u589e\u52a0\u4e86Tensorflow\u7684\u5b9e\u73b0\uff0c\u5e76\u4e14\u57282019\u5e74\u4e0b\u534a\u5e74\u6539\u540d\u4e3aTransformers\u3002\u622a\u6b62\u5199\u6587\u7ae0\u65f6\uff082021\u5e743\u670830\u65e5\uff09\u8fd9\u4e00\u9879\u76ee\u5df2\u7ecf\u62e5\u670943k+\u7684star\uff0c\u53ef\u4ee5\u8bf4Transformers\u5df2\u7ecf\u6210\u4e3a\u4e8b\u5b9e\u4e0a\u7684NLP\u57fa\u672c\u5de5\u5177\u3002<a target=\"_blank\" href=\"https:\/\/github.com\/huggingface\/transformers\" rel=\"noreferrer noopener\">https:\/\/github.com\/huggingface\/transformers\u200bgithub.com\/huggingface\/transformers<\/a><\/p>\n\n\n\n<p>\u672c\u6587\u57fa\u4e8eTransformers\u7248\u672c4.4.2\uff082021\u5e743\u670819\u65e5\u53d1\u5e03\uff09\u9879\u76ee\u4e2d\uff0cpytorch\u7248\u7684BERT\u76f8\u5173\u4ee3\u7801\uff0c\u4ece\u4ee3\u7801\u7ed3\u6784\u3001\u5177\u4f53\u5b9e\u73b0\u4e0e\u539f\u7406\uff0c\u4ee5\u53ca\u4f7f\u7528\u7684\u89d2\u5ea6\u8fdb\u884c\u5206\u6790\uff0c\u5305\u542b\u4ee5\u4e0b\u5185\u5bb9\uff1a<\/p>\n\n\n\n<ol><li><strong>BERT Tokenization\u5206\u8bcd\u6a21\u578b\uff08BertTokenizer\uff09<\/strong><\/li><li><strong>BERT 
1. **BERT Tokenization model (BertTokenizer)**
2. **BERT Model proper (BertModel)**
   1. **BertEmbeddings**
   2. **BertEncoder**
      1. **BertLayer**
         1. **BertAttention**
            1. **BertSelfAttention**
            2. **BertSelfOutput**
         2. **BertIntermediate**
         3. **BertOutput**
      2. **BertPooler**
3. **BERT-based application models (see the next post)**
   1. **BertForPreTraining**
   2. **BertForSequenceClassification**
   3. **BertForMultipleChoice**
   4. **BertForTokenClassification**
   5. **BertForQuestionAnswering**
4. **BERT training and optimization (see the next post)**
   1. **Pre-Training**
   2. **Fine-Tuning**
      1. **AdamW**
      2. **Warmup**

## 1 Tokenization (BertTokenizer)

The BERT-related tokenizers live mainly in `/models/bert/tokenization_bert.py` and `/models/bert/tokenization_bert_fast.py`.

These two files implement the basic `BertTokenizer` and the faster, Rust-backed `BertTokenizerFast`; this post mainly explains the first.

```python
class BertTokenizer(PreTrainedTokenizer):
    """
    Construct a BERT tokenizer. Based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.
    ...
    """
```
`BertTokenizer` is a tokenizer built on `BasicTokenizer` and `WordPieceTokenizer`:

- **`BasicTokenizer`** handles the first step: splitting the sentence by punctuation, whitespace, etc., optionally lowercasing, and cleaning up illegal characters.
  - Chinese characters are split character by character via preprocessing (inserting spaces around them);
  - `never_split` can be used to keep certain words from being split;
  - this step is optional (performed by default).
- **`WordPieceTokenizer`** goes one step further and splits words into **subwords**.
  - A subword sits between a character and a word: it preserves some of the word's meaning while handling the vocabulary explosion caused by English plurals and tenses, as well as the OOV (Out-Of-Vocabulary) problem of unseen words. Splitting off roots and affixes shrinks the vocabulary and also lowers the training difficulty;
  - for example, the word "tokenizer" can be split into "token" and "##izer", where the "##" on the second piece means it attaches to the previous one.

`BertTokenizer` has the following commonly used methods:

- **`from_pretrained`**: initialize a tokenizer from a directory containing the vocabulary file (vocab.txt);
- **`tokenize`**: split text (a word or sentence) into a list of subwords;
- **`convert_tokens_to_ids`**: convert a list of subwords into the list of their vocabulary indices;
- **`convert_ids_to_tokens`**: the inverse of the previous method;
- **`convert_tokens_to_string`**: join a subword list back into a word or sentence using the "##" markers;
- **`encode`**: for a single sentence, tokenize and add the special tokens to form the structure "[CLS], x, [SEP]", then convert to the list of vocabulary indices; for two input sentences (only the first two are used if more are given), tokenize and add special tokens to form "[CLS], x1, [SEP], x2, [SEP]" and convert to an index list;
- **`decode`**: turn the output of `encode` back into a full sentence.

And calling the tokenizer instance itself:

```python
>>> from transformers import BertTokenizer
>>> bt = BertTokenizer.from_pretrained('./bert-base-uncased/')
>>> bt('I like natural language progressing!')
{'input_ids': [101, 1045, 2066, 3019, 2653, 27673, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
```
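Putting a few of these methods together, here is a minimal round-trip sketch (assuming the same local `bert-base-uncased` vocabulary directory as above; the exact subword splits depend on the vocabulary):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased/')

# tokenize: text -> subword list; note the "##" continuation prefix
tokens = tokenizer.tokenize('I like tokenizers')   # e.g. ['i', 'like', 'token', '##izer', '##s']

# convert_tokens_to_ids: subword list -> vocabulary indices
ids = tokenizer.convert_tokens_to_ids(tokens)

# encode additionally wraps the ids in [CLS] ... [SEP]
ids_with_special = tokenizer.encode('I like tokenizers')

# decode restores a readable sentence from an id list
print(tokenizer.decode(ids_with_special))          # "[CLS] i like tokenizers [SEP]"
```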
---

## 2 Model (BertModel)

The BERT model code lives mainly in `/models/bert/modeling_bert.py`, a file of a thousand-plus lines containing the basic structure of the BERT model and the fine-tuning models built on top of it.

Let's start the analysis from the model proper:

```python
class BertModel(BertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """
```

BertModel is mainly a transformer encoder structure with three parts:

1. **`embeddings`**, an instance of the `BertEmbeddings` class, i.e. the word embeddings;
2. **`encoder`**, an instance of the `BertEncoder` class;
3. **`pooler`**, an instance of the `BertPooler` class; this part is optional.

> **Note: BertModel can also be configured as a decoder, but that part is not discussed below.**

Next, the meaning of each parameter in BertModel's forward pass, and the return values:

```python
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ): ...
```

- **`input_ids`**: the list of vocabulary indices of the subwords produced by the tokenizer;
- **`attention_mask`**: in self-attention, this mask marks which subwords are real sentence content and which are padding; padding positions are filled with 0;
- **`token_type_ids`**: marks which sentence the subword currently belongs to (first sentence / second sentence / padding);
- **`position_ids`**: the position index of each token within the sentence;
- **`head_mask`**: used to invalidate certain attention computations in certain layers;
- **`inputs_embeds`**: if provided, `input_ids` is not needed; these skip the embedding lookup and enter the encoder computation directly as embeddings;
- **`encoder_hidden_states`**: only takes effect when BertModel is configured as a decoder, in which case cross-attention is performed instead of self-attention;
- **`encoder_attention_mask`**: likewise, marks the padding of the encoder-side input in cross-attention;
- **`past_key_values`**: this parameter apparently passes in precomputed K-V products, to reduce the cost of cross-attention (since that part would otherwise be computed repeatedly);
- **`use_cache`**: saves and returns the previous parameter, to speed up decoding;
- **`output_attentions`**: whether to return the attention output of every intermediate layer;
- **`output_hidden_states`**: whether to return the output of every intermediate layer;
- **`return_dict`**: whether to return the output as key-value pairs (a ModelOutput class, which can also be used as a tuple); defaults to True.

> **Note: the head_mask's invalidation of attention computations here is different from the attention-head pruning discussed below; it merely multiplies certain attention results by this coefficient.**

The return part looks like this:

```python
        # The return part of BertModel's forward pass
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
```

As you can see, the return value contains not only the encoder and pooler outputs but also the other requested outputs (hidden_states, attentions, etc., which live in `encoder_outputs[1:]`), for easy access:

```python
        # The return part of BertEncoder's forward pass, i.e. the encoder_outputs above
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
```
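To see these inputs and outputs in action, a minimal forward pass might look like this (a sketch, assuming the same local checkpoint directory as in section 1):

```python
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased/')
model = BertModel.from_pretrained('./bert-base-uncased/')
model.eval()

inputs = tokenizer('I like natural language processing!', return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden), here (1, 8, 768)
print(outputs.pooler_output.shape)      # (batch, hidden), here (1, 768)
print(len(outputs.hidden_states))       # 13: the embedding output plus 12 layers
print(len(outputs.attentions))          # 12 attention maps, one per layer
```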
In addition, BertModel has the following methods, handy for BERT players to pull off all kinds of tricks:

1. **`get_input_embeddings`**: extract the word_embeddings, i.e. the word-vector part, from the embeddings;
2. **`set_input_embeddings`**: assign to the word_embeddings in the embeddings;
3. **`_prune_heads`**: a function for pruning attention heads; the input is a dict of `{layer_num: list of heads to prune in this layer}`, which prunes the specified heads of the specified layers.

> **Note: pruning is a complicated operation. The Wq, Wk, Wv weights of the retained attention heads, together with the weights of the fully connected layer applied after concatenation, must be copied into new, smaller weight matrices (disable grad before copying!), and the pruned heads must be recorded on the fly so that indices do not go wrong afterwards. See the `prune_heads` method of the `BertAttention` part for details.**

### 2.1 BertEmbeddings

The embedding is obtained by summing three parts:

![](https://pic3.zhimg.com/v2-58b65365587f269bc76358016414dc26_r.jpg)

1. **word_embeddings**, the embeddings of the subwords described above.
2. **token_type_embeddings**, which identify the sentence a token currently belongs to, helping distinguish sentences from padding and the two sentences of a pair from each other.
3. **position_embeddings**, the position embedding of each token in the sentence, used to distinguish word order. Unlike the design in the transformer paper, these are trained rather than fixed embeddings computed by sinusoidal functions, which is generally considered bad for extrapolation (hard to transfer directly to longer sentences).

The three embeddings are summed without weights and passed through a LayerNorm + dropout before being output, with shape `(batch_size, sequence_length, hidden_size)`.

> **Note: why LayerNorm + Dropout here? And why LayerNorm rather than BatchNorm? A good answer for reference:**

[Why does the transformer use layer normalization rather than other normalization methods? (Zhihu)](https://www.zhihu.com/question/395811291/answer/1260290120)
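The sum-then-normalize pipeline is compact enough to sketch in full (a stripped-down stand-in for `BertEmbeddings`, with bert-base sizes assumed and the buffer/default handling of the real class omitted):

```python
import torch
from torch import nn

class SimpleBertEmbeddings(nn.Module):
    # Sketch of BertEmbeddings: three lookups, an unweighted sum, LayerNorm, dropout.
    def __init__(self, vocab_size=30522, max_position=512, type_vocab_size=2,
                 hidden_size=768, dropout=0.1, eps=1e-12):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_position, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, token_type_ids):
        seq_length = input_ids.size(1)
        # default position ids: 0..seq_length-1, broadcast over the batch
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        embeddings = (self.word_embeddings(input_ids)
                      + self.position_embeddings(position_ids)
                      + self.token_type_embeddings(token_type_ids))
        # output shape: (batch_size, sequence_length, hidden_size)
        return self.dropout(self.LayerNorm(embeddings))
```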
### 2.2 BertEncoder

This contains a stack of BertLayer modules. The module itself needs no special explanation, but one detail is worth noting:

It uses the **gradient checkpointing** technique to reduce GPU memory usage during training.

> **Note: gradient checkpointing compresses the model's memory footprint by saving fewer computation-graph nodes, at the cost of recomputing the unsaved values when gradients are computed. See the paper "Training Deep Nets with Sublinear Memory Cost". The process is illustrated below:**

![animated illustration](https://pic2.zhimg.com/v2-24dfc50af29690e09dd5e8cc3319847d_b.jpg)

In BertEncoder, gradient checkpointing is implemented via `torch.utils.checkpoint.checkpoint`, which is fairly convenient to use; see the docs: [torch.utils.checkpoint (PyTorch 1.8.1 documentation)](https://pytorch.org/docs/stable/checkpoint.html)

The concrete implementation of this mechanism is rather involved (I didn't fully get it) and is not expanded on here.
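As a minimal illustration of the API (not the exact call BertEncoder makes), wrapping a block in `torch.utils.checkpoint.checkpoint` looks like this; the activations inside the block are recomputed during backward instead of being stored:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# a toy FFN block standing in for one expensive layer
block = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))

x = torch.randn(2, 8, 768, requires_grad=True)
# forward runs normally, but intermediate activations inside `block`
# are dropped and recomputed when backward() reaches this node
y = checkpoint(block, x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8, 768])
```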
Going one level deeper, we enter a single layer of the encoder:

### 2.2.1 BertLayer

This layer wraps BertAttention and BertIntermediate + BertOutput (the FFN part after the attention), plus the cross-attention part that is simply ignored here (it only matters when BERT is used as a decoder).

In principle, calling the three submodules in order would suffice; nothing worth mentioning there.

However, another **detail** appears here:

![](https://pic3.zhimg.com/80/v2-ac4943140c0ce63842cf4413f5511246_1440w.jpg)

```python
        # part of forward
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # some lines omitted ...

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # some lines omitted ...

        return outputs

    # the feed_forward_chunk part
    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
```

See that `apply_chunking_to_forward` and `feed_forward_chunk` up there (why go to all this trouble instead of just calling the module directly)?

So what exactly is this `apply_chunking_to_forward`? Let's take a closer look ...

```python
def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
    """
    This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
    dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory.

    If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as
    directly applying :obj:`forward_fn` to :obj:`input_tensors`.
    ...
    """
```

Ah, yet another memory-saving technique: a wrapper that computes the forward function over small batches / one slice of a dimension at a time. The `chunk_size` parameter is essentially the size of each split batch, `chunk_dim` the dimension to compute over in one go; the chunk results are concatenated and returned at the end.

By default, though, these two values are not specially set (0 and 1 respectively in the source code), so this is exactly equivalent to the normal forward pass.
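A small sketch to see the equivalence: chunking the sequence dimension of an FFN input gives the same result as a direct call, because the FFN is applied position-wise (toy sizes assumed; in version 4.4.2 the helper lives in `transformers.modeling_utils`):

```python
import torch
from torch import nn
from transformers.modeling_utils import apply_chunking_to_forward

ffn = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
hidden_states = torch.randn(2, 8, 768)  # (batch, seq_len, hidden)

def forward_fn(x):
    return ffn(x)

# chunk the seq_len dimension (dim=1) into pieces of 2 positions each
chunked = apply_chunking_to_forward(forward_fn, 2, 1, hidden_states)
direct = forward_fn(hidden_states)
print(torch.allclose(chunked, direct, atol=1e-6))  # True
```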
Deeper still lies the core of the Transformer: the BertAttention part, and the FFN part right after it.

### 2.2.1.1 BertAttention

I assumed the attention implementation would live here; turns out there is yet another layer below ... The `self` member is the multi-head attention implementation, and the `output` member implements the post-attention series of fully connected + dropout + residual + LayerNorm operations.

```python
class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()
```

But first, this layer. Here is the pruning operation mentioned above, the `prune_heads` method:

```python
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
```

The implementation can be summarized as follows:

- `find_pruneable_heads_and_indices` locates the heads to prune and the indices of the dimensions to keep;
- `prune_linear_layer` copies the surviving dimensions of the Wk/Wq/Wv weight matrices (together with their biases) into new, smaller matrices.

A quick way to see the effect is to prune a live model, as sketched below.
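For example, pruning two heads from the first layer of a loaded model could look like this (a sketch, assuming the local checkpoint directory used earlier; `prune_heads` on the model takes the `{layer: [heads]}` dict described above):

```python
from transformers import BertModel

model = BertModel.from_pretrained('./bert-base-uncased/')

layer0 = model.encoder.layer[0].attention
print(layer0.self.num_attention_heads, layer0.self.all_head_size)  # 12 768

# prune heads 0 and 1 of layer 0; query/key/value shrink accordingly
model.prune_heads({0: [0, 1]})
print(layer0.self.num_attention_heads, layer0.self.all_head_size)  # 10 640
print(layer0.self.query)  # Linear(in_features=768, out_features=640, bias=True)
```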
Next comes the main event: the concrete implementation of self-attention.

### 2.2.1.1.1 BertSelfAttention

**Heads-up: this is arguably the core region of the model, and the only place involving formulas, so a lot of code follows.**

Initialization:

```python
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder
```

- Besides the familiar query, key, value weights and a dropout, there is a mysterious position_embedding_type, plus a decoder flag (and no, I am not going to cover the cross-attention part);
- note that hidden_size and all_head_size start out identical. Why the seemingly redundant extra variable? Clearly because of the pruning function above: after a few attention heads are pruned, all_head_size naturally shrinks;
- hidden_size must be an integer multiple of num_attention_heads; for bert-base, each attention module has 12 heads and hidden_size is 768, so each head has size attention_head_size = 768/12 = 64;
- what is position_embedding_type? Read on ...

Then the key part, the forward pass.

First, recall the basic formulas of multi-head self-attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

$$\mathrm{MultiHead}(Q, K, V) = [\mathrm{head}_1; \dots; \mathrm{head}_{|h|}]\; W^O$$

where $|h|$ is the number of attention heads, $[\cdot]$ denotes vector concatenation, and $W^O \in \mathbb{R}^{|h| d_v \times d_x}$.

These attention heads are, famously, computed in parallel, which is why the query, key, value weights above are single matrices: the heads do not share weights; their weights are "concatenated" together.

> **Note: the original paper's rationale for multiple heads is: "Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this." Another fairly credible analysis:**
[Why does the Transformer need Multi-head Attention? (Zhihu)](https://www.zhihu.com/question/341222779/answer/814111138)

Look at the forward method:

```python
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # part of the cross-attention computation omitted
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # ...
```

- `transpose_for_scores` splits hidden_size into the multi-head output shape and transposes the middle two dimensions for the matrix multiplication;
- key_layer/value_layer/query_layer have shape `(batch_size, num_attention_heads, sequence_length, attention_head_size)`;
- attention_scores has shape `(batch_size, num_attention_heads, sequence_length, sequence_length)`, matching the attention maps computed independently by each head.
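A standalone shape check makes the reshuffling concrete (bert-base sizes assumed):

```python
import torch

batch_size, seq_len, hidden = 2, 8, 768
heads, head_size = 12, 64  # hidden == heads * head_size

x = torch.randn(batch_size, seq_len, hidden)

# transpose_for_scores: (b, l, h*d) -> (b, l, h, d) -> (b, h, l, d)
x = x.view(batch_size, seq_len, heads, head_size).permute(0, 2, 1, 3)
print(x.shape)  # torch.Size([2, 12, 8, 64])

# each head's QK^T is (l, d) @ (d, l) -> (l, l), batched over (b, h)
scores = torch.matmul(x, x.transpose(-1, -2))
print(scores.shape)  # torch.Size([2, 12, 8, 8])
```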
At this point K and Q have been multiplied to obtain the raw attention scores; per the formula, the next steps should be scaling by $\sqrt{d_k}$ and a softmax. However ...

What appears first is a strange `positional_embedding`, plus a pile of Einstein summations:

```python
        # ...
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
        # ...
```

> **Note: for the Einstein summation convention, see the docs:**

[torch.einsum (PyTorch 1.8.1 documentation)](https://pytorch.org/docs/stable/generated/torch.einsum.html)

> **Note: this `positional_embedding` injects position embeddings directly into the attention map. Why do it this way? I haven't figured it out yet ...**

For the different values of `position_embedding_type`, there are three behaviors (a toy sketch of the relative variants follows the list):

- `absolute`: the default; nothing extra happens here;
- `relative_key`: the relative position embedding is contracted with the query (the einsum above) and added to the attention scores as key-side position information;
- `relative_key_query`: per the code, position-score terms are added for both the query side and the key side.
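To unpack the einsum: `distance` is an (L, L) matrix of relative offsets, shifted into a non-negative range for the embedding lookup; `"bhld,lrd->bhlr"` then contracts the head dimension d to produce one score per (query position l, key position r) pair. A toy check (assumed sizes):

```python
import torch
from torch import nn

seq_len, max_pos, head_size = 4, 512, 64
distance_embedding = nn.Embedding(2 * max_pos - 1, head_size)

pos_l = torch.arange(seq_len).view(-1, 1)   # query positions, column vector
pos_r = torch.arange(seq_len).view(1, -1)   # key positions, row vector
distance = pos_l - pos_r                    # (4, 4), entries in [-(L-1), L-1]
print(distance[0], distance[-1])            # tensor([ 0, -1, -2, -3]) tensor([3, 2, 1, 0])

rel_emb = distance_embedding(distance + max_pos - 1)   # (4, 4, 64), indices now >= 0
query_layer = torch.randn(2, 12, seq_len, head_size)   # (b, h, l, d)
scores = torch.einsum("bhld,lrd->bhlr", query_layer, rel_emb)
print(scores.shape)  # torch.Size([2, 12, 4, 4])
```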
Setting this puzzling part aside for now, back to the normal attention flow:

```python
        # ...
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask  # why + rather than *?

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # decoder return values omitted ...
        return outputs
```

**Big question: what is `attention_scores = attention_scores + attention_mask` doing here? Shouldn't it be multiplying by the mask?**

**Because the attention_mask at this point has already been tampered with: positions that were originally 1 have become 0, and positions that were originally 0 (i.e. padding) have become a large negative number, so the addition yields a large negative value:**

- **Why a large negative number? Because after the softmax operation, such a term turns into a decimal close to 0.**

```
(Pdb) attention_mask
tensor([[[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        ...,
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]]],
       device='cuda:0')
```
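A two-line numeric check of why the additive trick works (toy scores assumed):

```python
import torch

scores = torch.tensor([2.0, 1.0, 3.0, 0.5])
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])   # last two positions are padding

additive_mask = (1.0 - mask) * -10000.0     # [0, 0, -10000, -10000]
probs = torch.softmax(scores + additive_mask, dim=-1)
print(probs)  # ~[0.73, 0.27, 0.00, 0.00]; padding gets (almost) zero weight
```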
So where is this step performed?

I found no answer in `modeling_bert.py`, but `modeling_utils.py` contains a special class, `class ModuleUtilsMixin`, and its `get_extended_attention_mask` method reveals the trick:

```python
    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # partly omitted ...

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
```

So when is this function called? What does it have to do with `BertModel`?

OK, this involves BertModel's inheritance details: `BertModel` inherits from `BertPreTrainedModel`, which inherits from `PreTrainedModel`, and `PreTrainedModel` in turn inherits from the three base classes `[nn.Module, ModuleUtilsMixin, GenerationMixin]`. What elaborate wrapping!

**In other words, `BertModel` must call `get_extended_attention_mask` on the original `attention_mask` at some intermediate step, turning the values of `attention_mask` from the original `[1, 0]` into `[0, -1e4]`.**

There is only one truth!

![](https://pic3.zhimg.com/80/v2-e018b03e763940c2db72bde830cba552_1440w.jpg)

And indeed, the call turns up in `BertModel`'s forward pass (line 944):

```python
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
```

Mystery solved: this method not only changes the mask's values but also broadcasts it into a shape that can be added directly to the attention map.

Well played, HuggingFace.

![](https://pic1.zhimg.com/v2-c1a98dd0675445d9f7adbe3bc65a28f4_r.jpg)
*They don't call it the facehugger for nothing!*

Beyond that, noteworthy details include:

- scaling is by each head's dimension: for bert-base that is the square root of 64, i.e. 8;
- attention_probs gets not only a softmax but also a dropout; worried the attention matrix is too dense? The code comment itself calls this unusual, but the original Transformer paper did it this way;
- head_mask is the per-head mask mentioned earlier; if unset it defaults to all ones and has no effect here;
- context_layer is the product of the attention matrix and the value matrix; its raw shape is `(batch_size, num_attention_heads, sequence_length, attention_head_size)`;
- after the transpose and view operations, context_layer's shape is restored to `(batch_size, sequence_length, hidden_size)`.

OK, that's all for attention.

### 2.2.1.1.2 BertSelfOutput

Quite a few operations here, but nothing complicated; self-evident:

```python
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
```

> **Note: here is the LayerNorm and Dropout combination again, except this time Dropout comes first, then the residual connection, then LayerNorm. As for why the residual connection: its most direct purpose is to reduce the training difficulty brought by great network depth and to stay more sensitive to the original input.**

### 2.2.1.2 BertIntermediate

With BertAttention done, there is still a fully connected + activation operation after the attention:

```python
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
```
- The fully connected layer expands the dimension: for bert-base, the expanded dimension is 3072, four times the original 768.

> **Note: why pass through an FFN at all? No idea ... a recent Google paper seems to show that attention-only models are useless:**

[Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth](https://arxiv.org/abs/2103.03404)

- The activation function defaults to `gelu` (Gaussian Error Linear Units): $\mathrm{GELU}(x) = x\,P(X \le x) = x\,\Phi(x)$. Since the Gaussian CDF $\Phi(x)$ has no elementary closed form, it is approximated in practice by a tanh-based expression, $0.5x\bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715x^3)\bigr)\bigr)$.

For reference (image from the web):

![GELU curve](https://pic2.zhimg.com/v2-23d4e3ad622f5b8eb4aaa3a71d183939_r.jpg)

**As for why this activation function is used in the transformer ...**

![](https://pic2.zhimg.com/v2-bdde7321319cc41ef7588fcc70935685_r.jpg)

> **Note: from the studies I have seen, the claim is that GeLU outperforms ReLU and the like, to the point that subsequent language models all kept this activation function.**

### 2.2.1.3 BertOutput

Here we have yet another fully connected + dropout + LayerNorm, plus a residual connection:

```python
class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
```

**The operations here are not merely "related" to BertSelfOutput; apart from the input dimension they are exactly identical ... two components that are extremely easy to confuse.**

The remaining content, the application models based on BERT as well as the BERT-related optimizers and usage, will be introduced in detail in the next post.
### 2.2.3 BertPooler

This layer simply takes the first token of the sentence, i.e. the vector corresponding to `[CLS]`, passes it through a fully connected layer and an activation function, and outputs it:

***(this part is optional, since pooling can be done in many different ways)***

```python
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
```

## Takeaways · Summary

- HuggingFace's Bert implementation uses several memory-saving techniques:
  - gradient checkpointing: forward-pass nodes are not kept, but recomputed on demand;
  - apply_chunking_to_forward: the FFN part is computed over several small batches and lower dimensions;
- BertModel involves elaborate wrapping and quite a few components. For bert-base, the main components tally up as follows:
  - `Dropout` appears a total of `1+(1+1+1)x12=37` times;
  - `LayerNorm` appears a total of `1+(1+1)x12=25` times;
  - `dense` fully connected layers appear a total of `(1+1+1)x12+1=37` times, and not every `dense` comes with an activation function ...

```
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      # repeated for 11 more layers, until the tower collapses :)
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
```
\n\n\n\n<h2>Takeaways<\/h2>\n\n\n\n<ul><li>The HuggingFace implementation of the BERT model uses several memory-saving techniques:<ul><li>gradient checkpointing: forward activations are not kept; they are recomputed during the backward pass when needed;<\/li><li><code>apply_chunking_to_forward<\/code>: computes the feed-forward part in several small chunks to lower peak memory;<\/li><\/ul><\/li><li>BertModel involves heavy encapsulation and quite a few components. Taking bert-base as an example, the main components count up as follows:<ul><li><code>Dropout<\/code> appears <code>1+(1+1+1)x12=37<\/code> times in total;<\/li><li><code>LayerNorm<\/code> appears <code>1+(1+1)x12=25<\/code> times in total;<\/li><li>the <code>dense<\/code> fully-connected layer appears <code>(1+1+1)x12+1=37<\/code> times in total (the query\/key\/value projections are separate <code>Linear<\/code> layers on top of that), and not every <code>dense<\/code> is followed by an activation function…<\/li><\/ul><\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>BertModel(\n  (embeddings): BertEmbeddings(\n    (word_embeddings): Embedding(30522, 768, padding_idx=0)\n    (position_embeddings): Embedding(512, 768)\n    (token_type_embeddings): Embedding(2, 768)\n    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n    (dropout): Dropout(p=0.1, inplace=False)\n  )\n  (encoder): BertEncoder(\n    (layer): ModuleList(\n      (0): BertLayer(\n        (attention): BertAttention(\n          (self): BertSelfAttention(\n            (query): Linear(in_features=768, out_features=768, bias=True)\n            (key): Linear(in_features=768, out_features=768, bias=True)\n            (value): Linear(in_features=768, out_features=768, bias=True)\n            (dropout): Dropout(p=0.1, inplace=False)\n          )\n          (output): BertSelfOutput(\n            (dense): Linear(in_features=768, out_features=768, bias=True)\n            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n            (dropout): Dropout(p=0.1, inplace=False)\n          )\n        )\n        (intermediate): BertIntermediate(\n          (dense): Linear(in_features=768, out_features=3072, bias=True)\n        )\n        (output): BertOutput(\n          (dense): Linear(in_features=3072, out_features=768, bias=True)\n          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n          (dropout): Dropout(p=0.1, inplace=False)\n        )\n      )\n      # ...repeated for the remaining 11 layers, until the tower collapses :)\n    )\n  )\n  (pooler): BertPooler(\n    (dense): Linear(in_features=768, out_features=768, bias=True)\n    (activation): Tanh()\n  )\n)<\/code><\/pre>
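\n\n\n\n<p>These counts are easy to verify by walking the module tree. A small sketch, assuming <code>transformers<\/code> is installed and the <code>bert-base-uncased<\/code> weights are available:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from collections import Counter\n\nfrom transformers import BertModel\n\nmodel = BertModel.from_pretrained(\"bert-base-uncased\")\ncounts = Counter(type(m).__name__ for m in model.modules())\nprint(counts&#91;\"Dropout\"], counts&#91;\"LayerNorm\"], counts&#91;\"Linear\"])\n# 37 25 73 -- Linear is 73 rather than 37 because the query\/key\/value\n# projections count on top of the modules literally named `dense`.<\/code><\/pre>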
\n\n\n\n<h2>3 BERT-based Models<\/h2>\n\n\n\n<p>The BERT-based models are all written in <code>\/models\/bert\/modeling_bert.py<\/code>, covering both the BERT pretraining models and the BERT classification models; the UML diagram is shown below.<\/p>\n\n\n\n<h2><strong>The BERT model family in one diagram (save it and zoom in to view):<\/strong><\/h2>\n\n\n\n<figure class=\"wp-block-image\"><img src=\"https:\/\/pic1.zhimg.com\/v2-0e126f74d40d2db8bc133bc67f8055b4_r.jpg\" alt=\"\"\/><figcaption>Diagram drawn with Pyreverse<\/figcaption><\/figure>\n\n\n\n<p>First of all, every model below is built on the abstract base class <code>BertPreTrainedModel<\/code>, which itself derives from the larger base class <code>PreTrainedModel<\/code>. Here we only care about what <code>BertPreTrainedModel<\/code> does:<\/p>\n\n\n\n<ul><li>it initializes the model weights, and it maintains the identity flags inherited from <code>PreTrainedModel<\/code> as well as the class variables used when loading a model.<\/li><\/ul>\n\n\n\n<p>Let us start the analysis with the pretraining models.<\/p>\n\n\n\n<h3>3.1 BertForPreTraining<\/h3>\n\n\n\n<p>As is well known, BERT pretraining consists of two tasks:<\/p>\n\n\n\n<ul><li><strong>Masked Language Model (MLM)<\/strong>: randomly replace some words in the sentence with <code>[MASK]<\/code>, feed the sentence into BERT to encode the information of every word, and finally use the encoding at the <code>[MASK]<\/code> positions to predict the correct words there; this task trains the model to understand what a word means from its context;<\/li><li><strong>Next Sentence Prediction (NSP)<\/strong>: feed the sentence pair A and B into BERT and use the encoding of <code>[CLS]<\/code> to predict whether B is the sentence that follows A; this task trains the model to understand the relationship between two sentences.<\/li><\/ul>\n\n\n\n<figure class=\"wp-block-image\"><img src=\"https:\/\/pic4.zhimg.com\/v2-778b166945e69e7689cccfe7532e74e3_r.jpg\" alt=\"\"\/><figcaption>Image source: the web<\/figcaption><\/figure>\n\n\n\n<p>In the code, the model that combines these two tasks is <code>BertForPreTraining<\/code>, which contains two components:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class BertForPreTraining(BertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config)\n        self.cls = BertPreTrainingHeads(config)\n\n        self.init_weights()\n    <em># ...<\/em><\/code><\/pre>\n\n\n\n<p><code>BertModel<\/code> was covered in detail in the previous article (note that the default <code>add_pooling_layer=True<\/code> is used here, i.e. the output corresponding to <code>[CLS]<\/code> is extracted for the NSP task), while <code>BertPreTrainingHeads<\/code> is the prediction module in charge of both tasks:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class BertPreTrainingHeads(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n        self.seq_relationship = nn.Linear(config.hidden_size, 2)\n\n    def forward(self, sequence_output, pooled_output):\n        prediction_scores = self.predictions(sequence_output)\n        seq_relationship_score = self.seq_relationship(pooled_output)\n        return prediction_scores, seq_relationship_score<\/code><\/pre>\n\n\n\n<p>Yet another layer of wrapping: <code>BertPreTrainingHeads<\/code> wraps <code>BertLMPredictionHead<\/code> together with a plain linear layer for the NSP task. The NSP part does not get its own <code>BertXXXPredictionHead<\/code> wrapper here, presumably because it is too simple to be worth one…<\/p>\n\n\n\n<blockquote class=\"wp-block-quote\"><p><strong>Supplement: such a wrapper actually exists; it is called <code>BertOnlyNSPHead<\/code>, it just is not used here…<\/strong><\/p><\/blockquote>\n\n\n\n<p>Digging one level deeper into <code>BertPreTrainingHeads<\/code>:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class BertLMPredictionHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.transform = BertPredictionHeadTransform(config)\n\n        <em># The output weights are the same as the input embeddings, but there is<\/em>\n        <em># an output-only bias for each token.<\/em>\n        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\n\n        <em># Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`<\/em>\n        self.decoder.bias = self.bias\n\n    def forward(self, hidden_states):\n        hidden_states = self.transform(hidden_states)\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states<\/code><\/pre>\n\n\n\n<p>This class turns the output at each <code>[MASK]<\/code> position into a classification over the whole vocabulary, with every word as one category. Note that:<\/p>\n\n\n\n<ul><li>the class re-initializes an all-zero vector as the bias of the prediction weights;<\/li><li>the output shape of this class is <code>[batch_size, seq_length, vocab_size]<\/code>, i.e. for every token of every sentence, a score for each vocabulary entry (note that no softmax is applied here);<\/li><li>yet another wrapped class appears, <code>BertPredictionHeadTransform<\/code>, which performs some linear transformations:<\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>class BertPredictionHeadTransform(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if isinstance(config.hidden_act, str):\n            self.transform_act_fn = ACT2FN&#91;config.hidden_act]\n        else:\n            self.transform_act_fn = config.hidden_act\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.transform_act_fn(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states<\/code><\/pre>\n\n\n\n<blockquote class=\"wp-block-quote\"><p><strong>Supplement: it feels like this layer could be dropped, since the output shape does not change either. My personal understanding is that it mirrors the pooler: pass through one more dense layer before the classifier…<\/strong><\/p><\/blockquote>
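\n\n\n\n<p>To make the shapes above concrete, here is a tiny sketch that runs <code>BertPreTrainingHeads<\/code> on random tensors; the import path matches the module layout described here, but the inputs are made up for illustration:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom transformers import BertConfig\nfrom transformers.models.bert.modeling_bert import BertPreTrainingHeads\n\nconfig = BertConfig()                                     # bert-base defaults\nheads = BertPreTrainingHeads(config)\nsequence_output = torch.randn(2, 8, config.hidden_size)  # &#91;batch, seq_len, hidden]\npooled_output = torch.randn(2, config.hidden_size)        # &#91;CLS] after the pooler\nmlm_scores, nsp_scores = heads(sequence_output, pooled_output)\nprint(mlm_scores.shape)   # torch.Size(&#91;2, 8, 30522]) -- MLM logits, no softmax\nprint(nsp_scores.shape)   # torch.Size(&#91;2, 2])        -- NSP logits<\/code><\/pre>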
\n\n\n\n<p>Back to <code>BertForPreTraining<\/code>: let us see how the two losses are handled. Its forward pass differs from that of <code>BertModel<\/code> in taking two extra inputs, <code>labels<\/code> and <code>next_sentence_label<\/code>:<\/p>\n\n\n\n<ul><li><code><strong>labels<\/strong><\/code>: shape <code>[batch_size, seq_length]<\/code>, the labels for the MLM task. Note that tokens that were <em>not<\/em> masked are set to -100, and only the masked tokens carry their real ids, which is the reverse of how the task itself is set up.<ul><li>For example, if the original sentence is <code>I want to [MASK] an apple<\/code> and the word <code>eat<\/code> was masked out before feeding the model, the corresponding <code>label<\/code> is <code>[-100, -100, -100, (id of eat), -100, -100]<\/code>;<\/li><li>Why -100 rather than some other number? Because <code>torch.nn.CrossEntropyLoss<\/code> defaults to <code>ignore_index=-100<\/code>, i.e. no loss is computed for positions whose label is -100.<\/li><\/ul><\/li><li><code><strong>next_sentence_label<\/strong><\/code>: this input is simple, just a 0\/1 binary classification label.<\/li><\/ul>
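\n\n\n\n<p>As a tiny illustration of the -100 convention, the sketch below computes an MLM-style loss for that example; all token ids here are made up for demonstration:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch.nn import CrossEntropyLoss\n\nvocab_size = 30522\n# Hypothetical ids for \"I want to &#91;MASK] an apple\"; 3280 stands in for \"eat\".\nlabels = torch.tensor(&#91;&#91;-100, -100, -100, 3280, -100, -100]])\nprediction_scores = torch.randn(1, 6, vocab_size)  # stand-in for the MLM head output\n\nloss_fct = CrossEntropyLoss()                      # ignore_index=-100 by default\nloss = loss_fct(prediction_scores.view(-1, vocab_size), labels.view(-1))\n# Only the &#91;MASK] position contributes; the -100 positions are ignored.<\/code><\/pre>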
\n\n\n\n<p>The corresponding forward signature:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>    <em># ...<\/em>\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        next_sentence_label=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ): ...<\/code><\/pre>\n\n\n\n<p>OK, next comes the combination of the two losses:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>        <em># ...<\/em>\n        total_loss = None\n        if labels is not None and next_sentence_label is not None:\n            loss_fct = CrossEntropyLoss()\n            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))\n            total_loss = masked_lm_loss + next_sentence_loss\n        <em># ...<\/em><\/code><\/pre>\n\n\n\n<p>They are simply added together; the strategy really is that plain.<\/p>\n\n\n\n<p>Of course, this file also contains BERT models that pretrain on only one of the two objectives (details not expanded here):<\/p>\n\n\n\n<ul><li><code><strong>BertForMaskedLM<\/strong><\/code>: pretrains on the MLM task only;<ul><li>based on <code>BertOnlyMLMHead<\/code>, which is itself yet another wrapper around <code>BertLMPredictionHead<\/code>;<\/li><\/ul><\/li><li><strong><code>BertLMHeadModel<\/code><\/strong>: differs from the previous one in that this model <strong>runs as a decoder<\/strong>;<ul><li>likewise based on <code>BertOnlyMLMHead<\/code>;<\/li><\/ul><\/li><li><code><strong>BertForNextSentencePrediction<\/strong><\/code>: pretrains on the NSP task only;<ul><li>based on <code>BertOnlyNSPHead<\/code>, which is nothing more than a linear layer…<\/li><\/ul><\/li><\/ul>\n\n\n\n<p>Next come the various fine-tuning models, which are basically all classification tasks:<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img src=\"https:\/\/pic1.zhimg.com\/v2-d870cb6a4cc1b6f5f7f54cd9f563e468_r.jpg\" alt=\"\"\/><figcaption>Source: the appendix of the original BERT paper<\/figcaption><\/figure>\n\n\n\n<h3>3.2 BertForSequenceClassification<\/h3>\n\n\n\n<p>This model is used for sentence-level classification (or regression) tasks, such as the tasks of the GLUE benchmark.<\/p>\n\n\n\n<ul><li>The input of sentence classification is a sentence (or sentence pair); the output is a single classification label.<\/li><\/ul>\n\n\n\n<p>Structurally it is very simple: <code>BertModel<\/code> (with pooling) followed by a dropout and then a linear layer that produces the classification output:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class BertForSequenceClassification(BertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.bert = BertModel(config)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n\n        self.init_weights()\n        <em># ...<\/em><\/code><\/pre>\n\n\n\n<p>In the forward pass, a <code>labels<\/code> input is expected, just as in the pretraining model above; the loss choice is sketched below.<\/p>\n\n\n\n<ul><li>If the model was initialized with <code>num_labels=1<\/code>, it is treated as a regression task by default and MSELoss is used;<\/li><li>otherwise it is treated as a classification task.<\/li><\/ul>
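\n\n\n\n<p>The corresponding branch of the forward pass looks roughly like this (abridged; the surrounding code and imports are omitted):<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>        <em># ...<\/em>\n        if labels is not None:\n            if self.num_labels == 1:\n                <em># regression<\/em>\n                loss_fct = MSELoss()\n                loss = loss_fct(logits.view(-1), labels.view(-1))\n            else:\n                <em># classification<\/em>\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n        <em># ...<\/em><\/code><\/pre>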
\n\n\n\n<h3>3.3 BertForMultipleChoice<\/h3>\n\n\n\n<p>This model is used for multiple-choice tasks, such as RocStories\/SWAG.<\/p>\n\n\n\n<ul><li>The input of a multiple-choice task is a group of candidate sentences per sample; the output is a single label selecting one of them.<\/li><\/ul>\n\n\n\n<p>Structurally it resembles sentence classification, except that the linear layer has an output dimension of 1, so the outputs of all candidate sentences of a sample are regrouped to form that sample's prediction scores, as sketched after this list.<\/p>\n\n\n\n<ul><li>In practice, the candidate sentences of each batch are fed in together, so a single pass processes <code>[batch_size, num_choices]<\/code> sentences; at the same batch size this needs noticeably more memory than sentence classification, which is worth keeping in mind during training.<\/li><\/ul>
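\n\n\n\n<p>A minimal sketch of that reshaping, with made-up shapes; it mirrors the flatten-then-regroup logic rather than quoting the literal source:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\n\nbatch_size, num_choices, seq_len = 2, 4, 16\ninput_ids = torch.randint(0, 30522, (batch_size, num_choices, seq_len))\n\n# Flatten the choice dimension so BertModel sees an ordinary batch:\nflat_input_ids = input_ids.view(-1, seq_len)         # &#91;batch*num_choices, seq_len]\n\n# The classifier outputs one score per sentence; regroup per sample so a\n# softmax (inside CrossEntropyLoss) runs over the choices:\nlogits = torch.randn(batch_size * num_choices, 1)    # stand-in for classifier output\nreshaped_logits = logits.view(-1, num_choices)       # &#91;batch_size, num_choices]<\/code><\/pre>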
\n\n\n\n<h3>3.4 BertForTokenClassification<\/h3>\n\n\n\n<p>This model is used for sequence labeling (token classification), e.g. NER tasks.<\/p>\n\n\n\n<ul><li>The input of sequence labeling is a single sentence; the output is a category label for every token.<\/li><\/ul>\n\n\n\n<p>Since the output of every token is needed rather than just a few of them, the <code>BertModel<\/code> here is built without the pooling layer;<\/p>\n\n\n\n<ul><li>at the same time, the class attribute <code>_keys_to_ignore_on_load_unexpected<\/code> is set to <code>[r\"pooler\"]<\/code>, so that no error is raised about the now-unneeded pooler weights when loading a pretrained checkpoint.<\/li><\/ul>\n\n\n\n<h3>3.5 BertForQuestionAnswering<\/h3>\n\n\n\n<p>This model is used for extractive question answering, e.g. the SQuAD task.<\/p>\n\n\n\n<ul><li>The input of the QA task is the sentence pair formed by the question plus (for BERT, exactly one) passage containing the answer; the output is a start position and an end position that mark the answer text within the passage.<\/li><\/ul>\n\n\n\n<p>Two outputs are needed here, one predicting the start position and one predicting the end position. Both have the same length as the sentence, and the index of the maximum predicted value is taken as the predicted position.<\/p>\n\n\n\n<ul><li>Illegal labels that fall outside the sentence length are clamped (<code>torch.clamp_<\/code>) into the valid range.<\/li><\/ul>\n\n\n\n<blockquote class=\"wp-block-quote\"><p>As a belated supplement, a brief word on the <code>ModelOutput<\/code> class. It is the base class that wraps the outputs of all the models above; it supports both dict-style access and index-based access, and it inherits from python's native <code>OrderedDict<\/code> class.<\/p><\/blockquote>\n\n\n\n<hr class=\"wp-block-separator\"\/>\n\n\n\n<p>That wraps up the walkthrough of the BERT source code; next come some practical training details for BERT models.<\/p>\n\n\n\n<h2>4 BERT Training and Optimization<\/h2>\n\n\n\n<h3>4.1 Pre-Training<\/h3>\n\n\n\n<p>In the pretraining stage, besides the well-known 15% masking rate and 80% <code>[MASK]<\/code> replacement ratio, one noteworthy detail is parameter sharing.<\/p>\n\n\n\n<p>Not only in BERT: in every PLM implemented by HuggingFace, the word-embedding weights and the masked-language-model prediction weights are shared at initialization time:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):\n    <em># ...<\/em>\n    def tie_weights(self):\n        \"\"\"\n        Tie the weights between the input embeddings and the output embeddings.\n\n        If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning\n        the weights instead.\n        \"\"\"\n        output_embeddings = self.get_output_embeddings()\n        if output_embeddings is not None and self.config.tie_word_embeddings:\n            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())\n\n        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:\n            if hasattr(self, self.base_model_prefix):\n                self = getattr(self, self.base_model_prefix)\n            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)\n    <em># ...<\/em><\/code><\/pre>
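\n\n\n\n<p>The sharing is easy to verify: after loading, the decoder weight of the MLM head and the input word embedding are literally the same <code>Parameter<\/code> object. A quick sketch, assuming the pretrained weights are available:<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from transformers import BertForPreTraining\n\nmodel = BertForPreTraining.from_pretrained(\"bert-base-uncased\")\nemb = model.bert.embeddings.word_embeddings.weight\ndec = model.cls.predictions.decoder.weight\nprint(dec is emb)   # True -- tied by tie_weights()\nprint(emb.shape)    # torch.Size(&#91;30522, 768])<\/code><\/pre>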
\n\n\n\n<p>As for why: presumably because the word-embedding and prediction weight matrices are so large. For bert-base each has size <code>(30522, 768)<\/code>, so sharing one matrix between the two roles removes a large block of parameters and lowers the training difficulty.<\/p>\n\n\n\n<h3>4.2 Fine-Tuning<\/h3>\n\n\n\n<p>Fine-tuning, i.e. the downstream-task stage, also has two points worth noting.<\/p>\n\n\n\n<h3>4.2.1 AdamW<\/h3>\n\n\n\n<p>First, a word on BERT's optimizer: <strong>AdamW<\/strong> (AdamWeightDecayOptimizer).<\/p>\n\n\n\n<p>This optimizer comes from the paper <em>Decoupled Weight Decay Regularization<\/em> (ICLR 2019; first circulated in 2017 under the title <em>Fixing Weight Decay Regularization in Adam<\/em>), which proposes a way to repair Adam's flawed weight decay. The paper points out that L2 regularization and weight decay are not equivalent in general; they coincide only under plain SGD. Yet most frameworks implement Adam + L2 regularization by means of a weight-decay parameter, so the two must not be conflated.<\/p>\n\n\n\n<p>AdamW is the algorithm obtained by improving on Adam + L2 regularization; it differs from ordinary Adam + L2 as follows:<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img src=\"https:\/\/pic1.zhimg.com\/v2-6338c0014e170d44d02869af8788b88c_r.jpg\" alt=\"\"\/><\/figure>
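\n\n\n\n<p>In code form, the difference comes down to where the decay term enters the update. A simplified sketch of a single parameter update (bias correction omitted; all names are local to this example, not the library API):<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\n\ndef adam_l2_step(p, g, m, v, lr=2e-5, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):\n    # Adam + L2: the decay is folded into the gradient, so it gets rescaled\n    # by the adaptive denominator along with everything else.\n    g = g + wd * p\n    m.mul_(beta1).add_(g, alpha=1 - beta1)\n    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)\n    p.sub_(lr * m \/ (v.sqrt() + eps))\n\ndef adamw_step(p, g, m, v, lr=2e-5, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):\n    # AdamW: the decay is applied to the weights directly, decoupled from\n    # the moment estimates.\n    m.mul_(beta1).add_(g, alpha=1 - beta1)\n    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)\n    p.sub_(lr * (m \/ (v.sqrt() + eps) + wd * p))<\/code><\/pre>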
class=\"wp-block-embed\"><div class=\"wp-block-embed__wrapper\">\nhttps:\/\/forums.fast.ai\/t\/is-weight-decay-applied-to-the-bias-term\/73212\/4\u200bforums.fast.ai\/t\/is-weight-decay-applied-to-the-bias-term\/73212\/4\n<\/div><\/figure>\n\n\n\n<pre class=\"wp-block-code\"><code>    <em># model: a Bert-based-model object<\/em>\n    <em># learning_rate: default 2e-5 for text classification<\/em>\n    param_optimizer = list(model.named_parameters())\n    no_decay = &#91;'bias', 'LayerNorm.bias', 'LayerNorm.weight']\n    optimizer_grouped_parameters = &#91;\n        {'params': &#91;p for n, p in param_optimizer if not any(\n            nd in n for nd in no_decay)], 'weight_decay': 0.01},\n        {'params': &#91;p for n, p in param_optimizer if any(\n            nd in n for nd in no_decay)], 'weight_decay': 0.0}\n    ]\n    optimizer = AdamW(optimizer_grouped_parameters,\n                      lr=learning_rate)\n    <em># ...<\/em><\/code><\/pre>\n\n\n\n<h3>4.2.2 Warmup<\/h3>\n\n\n\n<p>BERT\u7684\u8bad\u7ec3\u4e2d\u53e6\u4e00\u4e2a\u7279\u70b9\u5728\u4e8eWarmup\uff0c\u5176\u542b\u4e49\u4e3a\uff1a<\/p>\n\n\n\n<ul><li>\u5728\u8bad\u7ec3\u521d\u671f\u4f7f\u7528\u8f83\u5c0f\u7684\u5b66\u4e60\u7387\uff08\u4ece0\u5f00\u59cb\uff09\uff0c\u5728\u4e00\u5b9a\u6b65\u6570\uff08\u6bd4\u59821000\u6b65\uff09\u5185\u9010\u6e10\u63d0\u9ad8\u5230\u6b63\u5e38\u5927\u5c0f\uff08\u6bd4\u5982\u4e0a\u9762\u76842e-5\uff09\uff0c\u907f\u514d\u6a21\u578b\u8fc7\u65e9\u8fdb\u5165\u5c40\u90e8\u6700\u4f18\u800c\u8fc7\u62df\u5408\uff1b<\/li><li>\u5728\u8bad\u7ec3\u540e\u671f\u518d\u6162\u6162\u5c06\u5b66\u4e60\u7387\u964d\u4f4e\u52300\uff0c\u907f\u514d\u540e\u671f\u8bad\u7ec3\u8fd8\u51fa\u73b0\u8f83\u5927\u7684\u53c2\u6570\u53d8\u5316\u3002<\/li><\/ul>\n\n\n\n<p>\u5728Huggingface\u7684\u5b9e\u73b0\u4e2d\uff0c\u53ef\u4ee5\u4f7f\u7528\u591a\u79cdwarmup\u7b56\u7565\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>TYPE_TO_SCHEDULER_FUNCTION = {\n    SchedulerType.LINEAR: get_linear_schedule_with_warmup,\n    SchedulerType.COSINE: get_cosine_schedule_with_warmup,\n    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,\n    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,\n    SchedulerType.CONSTANT: get_constant_schedule,\n    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,\n}<\/code><\/pre>\n\n\n\n<p>\u5177\u4f53\u800c\u8a00\uff1a<\/p>\n\n\n\n<ul><li>CONSTANT\uff1a\u4fdd\u6301\u56fa\u5b9a\u5b66\u4e60\u7387\u4e0d\u53d8\uff1b<\/li><li>CONSTANT_WITH_WARMUP\uff1a\u5728\u6bcf\u4e00\u4e2astep\u4e2d\u7ebf\u6027\u8c03\u6574\u5b66\u4e60\u7387\uff1b<\/li><li>LINEAR\uff1a\u4e0a\u6587\u63d0\u5230\u7684\u4e24\u6bb5\u5f0f\u8c03\u6574\uff1b<\/li><li>COSINE\uff1a\u548c\u4e24\u6bb5\u5f0f\u8c03\u6574\u7c7b\u4f3c\uff0c\u53ea\u4e0d\u8fc7\u91c7\u7528\u7684\u662f\u4e09\u89d2\u51fd\u6570\u5f0f\u7684\u66f2\u7ebf\u8c03\u6574\uff1b<\/li><li>COSINE_WITH_RESTARTS\uff1a\u8bad\u7ec3\u4e2d\u5c06\u4e0a\u9762COSINE\u7684\u8c03\u6574\u91cd\u590dn\u6b21\uff1b<\/li><li>POLYNOMIAL\uff1a\u6309\u6307\u6570\u66f2\u7ebf\u8fdb\u884c\u4e24\u6bb5\u5f0f\u8c03\u6574\u3002<\/li><\/ul>\n\n\n\n<p>\u5177\u4f53\u4f7f\u7528\u53c2\u8003<code>transformers\/optimization.py<\/code>\uff1a<\/p>\n\n\n\n<ul><li>\u6700\u5e38\u7528\u7684\u8fd8\u662f<code>get_linear_scheduler_with_warmup<\/code>\u5373\u7ebf\u6027\u4e24\u6bb5\u5f0f\u8c03\u6574\u5b66\u4e60\u7387\u7684\u65b9\u6848\u2026\u2026<\/li><\/ul>\n\n\n\n<pre class=\"wp-block-code\"><code>def get_scheduler(\n    name: Union&#91;str, SchedulerType],\n    
optimizer: Optimizer,\n    num_warmup_steps: Optional&#91;int] = None,\n    num_training_steps: Optional&#91;int] = None,\n): ...<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"<p>\u6458\u81ea\uff1a https:\/\/zhuanlan.zhihu.com\/p\/360988428 \u4f17\u6240\u5468\u77e5\uff0cBERT\u6a21\u578b\u81ea &hellip; <a href=\"http:\/\/139.9.1.231\/index.php\/2022\/08\/13\/huggingface-transformers-bert\/\" class=\"more-link\">\u7ee7\u7eed\u9605\u8bfb<span class=\"screen-reader-text\">HuggingFace Transformers &#8212;-BERT \u6e90\u7801<\/span><\/a><\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":[],"categories":[17,21,4],"tags":[],"_links":{"self":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/5780"}],"collection":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/comments?post=5780"}],"version-history":[{"count":1,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/5780\/revisions"}],"predecessor-version":[{"id":5781,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/posts\/5780\/revisions\/5781"}],"wp:attachment":[{"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/media?parent=5780"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/categories?post=5780"},{"taxonomy":"post_tag","embeddable":true,"href":"http:\/\/139.9.1.231\/index.php\/wp-json\/wp\/v2\/tags?post=5780"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}