1. import torch
  2. from torch import nn
  3. from d2l import torch as d2l
  1. def get_tokens_and_segments(tokens_a, tokens_b=None):
  2. """获取输入序列的词元及其片段索引。"""
  3. tokens = ['<cls>'] + tokens_a + ['<sep>']
  4. # 0和1分别标记片段A和B
  5. segments = [0] * (len(tokens_a) + 2)
  6. if tokens_b is not None:
  7. tokens += tokens_b + ['<sep>']
  8. segments += [1] * (len(tokens_b) + 1)
  9. return tokens, segments
  1. class BERTEncoder(nn.Module):
  2. """BERT encoder."""
  3. def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
  4. ffn_num_hiddens, num_heads, num_layers, dropout,
  5. max_len=1000, key_size=768, query_size=768, value_size=768,
  6. **kwargs):
  7. super(BERTEncoder, self).__init__(**kwargs)
  8. self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
  9. self.segment_embedding = nn.Embedding(2, num_hiddens)
  10. self.blks = nn.Sequential()
  11. for i in range(num_layers):
  12. self.blks.add_module(f"{i}", d2l.EncoderBlock(
  13. key_size, query_size, value_size, num_hiddens, norm_shape,
  14. ffn_num_input, ffn_num_hiddens, num_heads, dropout, True))
  15. # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入参数
  16. self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
  17. num_hiddens))
  18. def forward(self, tokens, segments, valid_lens):
  19. # 在以下代码段中,`X`的形状保持不变:(批量大小,最大序列长度,`num_hiddens`)
  20. X = self.token_embedding(tokens) + self.segment_embedding(segments)
  21. X = X + self.pos_embedding.data[:, :X.shape[1], :]
  22. for blk in self.blks:
  23. X = blk(X, valid_lens)
  24. return X
  1. vocab_size, num_hiddens, ffn_num_hiddens, num_heads = 10000, 768, 1024, 4
  2. norm_shape, ffn_num_input, num_layers, dropout = [768], 768, 2, 0.2
  3. encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input,
  4. ffn_num_hiddens, num_heads, num_layers, dropout)
  1. tokens = torch.randint(0, vocab_size, (2, 8))
  2. segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]])
  3. encoded_X = encoder(tokens, segments, None)
  4. encoded_X.shape
  1. torch.Size([2, 8, 768])
  1. class MaskLM(nn.Module):
  2. """BERT的遮蔽语言模型任务"""
  3. def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
  4. super(MaskLM, self).__init__(**kwargs)
  5. self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
  6. nn.ReLU(),
  7. nn.LayerNorm(num_hiddens),
  8. nn.Linear(num_hiddens, vocab_size))
  9. def forward(self, X, pred_positions):
  10. num_pred_positions = pred_positions.shape[1]
  11. pred_positions = pred_positions.reshape(-1)
  12. batch_size = X.shape[0]
  13. batch_idx = torch.arange(0, batch_size)
  14. # 假设`batch_size=2,`num_pred_positions`=3
  15. # 那么`batch_idx`是`np.array([0,0,0,1,1])`
  16. batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
  17. masked_X = X[batch_idx, pred_positions]
  18. masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
  19. mlm_Y_hat = self.mlp(masked_X)
  20. return mlm_Y_hat
  1. mlm = MaskLM(vocab_size, num_hiddens)
  2. mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
  3. mlm_Y_hat = mlm(encoded_X, mlm_positions)
  4. mlm_Y_hat.shape
  1. torch.Size([2, 3, 10000])
  1. mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
  2. loss = nn.CrossEntropyLoss(reduction='none')
  3. mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
  4. mlm_l.shape
torch.Size([6])
class NextSentencePred(nn.Module):
    """Next-sentence-prediction (NSP) head of BERT.

    A single linear layer that maps each input vector to two logits.
    """
    def __init__(self, num_inputs, **kwargs):
        super().__init__(**kwargs)
        self.output = nn.Linear(in_features=num_inputs, out_features=2)

    def forward(self, X):
        """Return two-way logits; ``X`` is (batch size, num_inputs)."""
        logits = self.output(X)
        return logits
# Unlike MXNet's Dense layer (which folds all axes except the first when
# flatten=True), PyTorch does not flatten inputs by default, so collapse every
# axis after the batch axis explicitly.
encoded_X = encoded_X.flatten(start_dim=1)
# NSP input shape here: (batch size, sequence length * num_hiddens).
nsp = NextSentencePred(num_inputs=encoded_X.shape[-1])
nsp_Y_hat = nsp(encoded_X)
nsp_Y_hat.shape  # notebook output: torch.Size([2, 2])
# Unreduced cross-entropy: one NSP loss value per example.
nsp_y = torch.tensor([0, 1])
nsp_l = loss(input=nsp_Y_hat, target=nsp_y)
nsp_l.shape  # notebook output: torch.Size([2])

# 14.8.6. 把所有的东西放在一起 (Putting It All Together)

class BERTModel(nn.Module):
    """Complete BERT model: a ``BERTEncoder`` plus MLM and NSP heads.

    ``forward`` returns ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``; ``mlm_Y_hat``
    is ``None`` when no ``pred_positions`` are supplied.
    """
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super().__init__()
        # Keep this creation order: it fixes the order in which parameter
        # initialisation consumes the RNG and the state_dict key order.
        self.encoder = BERTEncoder(
            vocab_size, num_hiddens, norm_shape, ffn_num_input,
            ffn_num_hiddens, num_heads, num_layers, dropout,
            max_len=max_len, key_size=key_size, query_size=query_size,
            value_size=value_size)
        # Hidden layer of the NSP classifier, applied to the '<cls>' encoding.
        self.hidden = nn.Sequential(
            nn.Linear(hid_in_features, num_hiddens),
            nn.Tanh(),
        )
        self.mlm = MaskLM(vocab_size, num_hiddens, num_inputs=mlm_in_features)
        self.nsp = NextSentencePred(num_inputs=nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        """Run the encoder and both pre-training heads.

        Args:
            tokens: token ids passed to the encoder.
            segments: segment ids passed to the encoder.
            valid_lens: optional valid lengths forwarded to the encoder.
            pred_positions: optional positions for masked-LM prediction.

        Returns:
            ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``.
        """
        encoded_X = self.encoder(tokens, segments, valid_lens)
        mlm_Y_hat = (self.mlm(encoded_X, pred_positions)
                     if pred_positions is not None else None)
        # Index 0 is the '<cls>' token; its encoding drives next-sentence
        # prediction through the hidden MLP layer.
        cls_encoding = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(cls_encoding))
        return encoded_X, mlm_Y_hat, nsp_Y_hat