Bilibili video walkthrough
This article walks through how to reproduce BERT with PyTorch. Please spend ten minutes or so reading my article BERT 详解(附带 ELMo、GPT 介绍) (a detailed explanation of BERT, with ELMo and GPT) before reading this one; it will make everything here far easier to follow.

Preparing the Dataset

I did not use any large dataset here. Instead, I typed in a short two-person dialogue by hand, mainly to keep the code easy to read; I would like readers to focus on the model implementation itself.

# Import the required libraries
import re
import math
import copy  # used to clone the Encoder layer below
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
# The text we will train on
TEXT = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)

# BERT parameters
MAX_LEN = 30
BATCH_SIZE = 6
MAX_PRED = 5 # max number of tokens to predict
NUM_LAYERS = 6
NUM_HEADS = 12
EMBED_DIM = 768
FF_DIM = 768 * 4 # 4 * EMBED_DIM, FeedForward dimension
KEY_DIM = VALUE_DIM = 64 # dimension of K (= Q) and V
NUM_SEGMENTS = 2
  • MAX_LEN means every sentence in a batch is made up of 30 tokens, with shorter ones padded with [PAD] (my implementation here is rather crude: every sentence in every batch is simply fixed at 30)
  • MAX_PRED is the maximum number of tokens to predict, i.e. BERT's cloze-style masked LM task
  • NUM_LAYERS is the number of Encoder Layers
  • EMBED_DIM is the dimension of the Token Embeddings, Segment Embeddings, and Position Embeddings
  • FF_DIM is the dimension of the fully connected layer inside each Encoder Layer (4× EMBED_DIM; see the small sanity check after this list)
  • NUM_SEGMENTS indicates how many sentences the Encoder input is made up of
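
These sizes follow the BERT-base convention: the per-head dimension times the number of heads equals the model width, and the feed-forward layer is four times the model width. The code itself does not enforce this, but a quick sanity check makes the relationship explicit:

# BERT-base convention: EMBED_DIM is split evenly across the attention heads,
# and the feed-forward layer is 4x the model width.
assert KEY_DIM == VALUE_DIM == EMBED_DIM // NUM_HEADS   # 768 // 12 == 64
assert FF_DIM == 4 * EMBED_DIM                          # 3072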

Data Preprocessing

In the data preprocessing step, we need to randomly mask or replace (both referred to as "mask" from here on) 15% of the tokens in a sentence according to certain probabilities, and we also need to concatenate arbitrary pairs of sentences.

def make_data(text, max_len, batch_size, max_pred):
    sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
    word_list = list(set(" ".join(sentences).split()))
    word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
    for i, w in enumerate(word_list):
        word2idx[w] = i + 4
    idx2word = {i: w for i, w in enumerate(word2idx)}
    vocab_size = len(word2idx)

    token_list = []
    for sentence in sentences:
        sentence = [word2idx[w] for w in sentence.split()]
        token_list.append(sentence)

    batch = []
    positive = negative = 0
    while positive != batch_size / 2 or negative != batch_size / 2:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))  # sample two random sentence indices
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # Masked LM
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))  # mask 15% of the tokens in this sequence
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]  # candidate positions that may be masked
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            p = random()
            if p < 0.8:  # 80%: replace with [MASK]
                input_ids[pos] = word2idx['[MASK]']
            elif p < 0.9:  # 10%: replace with a random word
                index = randint(0, vocab_size - 1)  # random index in the vocabulary
                while index < 4:  # skip '[PAD]', '[CLS]', '[SEP]', '[MASK]'
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index  # replace
            # remaining 10%: keep the original token

        # Zero-pad the sequence up to max_len
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero-pad the masked-token bookkeeping up to max_pred
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])   # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])  # NotNext
            negative += 1

    return batch, vocab_size, idx2word
# Preprocessing finished

In the code above, the variable positive counts sentence pairs that really are consecutive, and negative counts pairs that are not; within a batch we keep these two kinds of samples at a 1:1 ratio. Whether the two randomly sampled sentences are consecutive can be decided simply by checking tokens_a_index + 1 == tokens_b_index.
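
A minimal illustration of that rule (the indices here are made up):

# A pair is labelled IsNext exactly when sentence b directly follows sentence a.
tokens_a_index, tokens_b_index = 2, 3
print(tokens_a_index + 1 == tokens_b_index)  # True  -> IsNext
tokens_a_index, tokens_b_index = 4, 1
print(tokens_a_index + 1 == tokens_b_index)  # False -> NotNext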
Next, some tokens are randomly masked. n_pred is the number of tokens that will be masked, and cand_maked_pos collects the candidate positions that are allowed to be masked (special tokens like [SEP] and [CLS] must not be masked, since predicting them would be meaningless). We shuffle() the candidates and then, based on a random() draw, decide whether to replace each chosen token with [MASK], replace it with some other token, or leave it unchanged.
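
To make the 80% / 10% / 10% split concrete, here is a tiny standalone check of the branching rule used in make_data (it is not part of the model):

from random import random
from collections import Counter

counts = Counter()
for _ in range(100_000):
    p = random()
    if p < 0.8:
        counts['[MASK]'] += 1        # 80%: replaced with [MASK]
    elif p < 0.9:
        counts['random word'] += 1   # 10%: replaced with a random token
    else:
        counts['unchanged'] += 1     # 10%: left as-is
print(counts)  # roughly 80k / 10k / 10k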
After that come two rounds of zero padding. The first pads each sequence so that every sentence in a batch has the same length. The second pads the mask bookkeeping: sentences of different lengths mask different numbers of tokens, and within a batch the number of masked entries must be the same, so we pad the remainder with meaningless placeholders such as 0.
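
As a concrete illustration of the two padding steps (the token ids below are made up; only the lengths matter):

input_ids   = [1, 5, 9, 23, 2, 17, 4, 2]     # [CLS] ... [SEP] ... [SEP], 8 tokens
segment_ids = [0, 0, 0, 0, 0, 1, 1, 1]
masked_tokens, masked_pos = [9, 17], [2, 5]  # two tokens were masked

# 1) pad the sequence itself up to MAX_LEN
n_pad = MAX_LEN - len(input_ids)             # 30 - 8 = 22
input_ids   += [0] * n_pad
segment_ids += [0] * n_pad

# 2) pad the mask bookkeeping up to MAX_PRED
n_pad = MAX_PRED - len(masked_tokens)        # 5 - 2 = 3
masked_tokens += [0] * n_pad
masked_pos    += [0] * n_pad

print(len(input_ids), len(masked_tokens))    # 30 5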
That is the whole of the data preprocessing.

Building the Model

The model architecture is essentially the Transformer Encoder, so I will not repeat the details here; see my article Transformer 的 PyTorch 实现 (a PyTorch implementation of the Transformer) and the accompanying Bilibili video.

batch, VOCAB_SIZE, idx2word = make_data(TEXT, MAX_LEN, BATCH_SIZE, MAX_PRED)
print(batch[5][1])  # inspect the segment_ids of one sample

input_ids, segment_ids, masked_tokens, masked_pos, isNext=zip(*batch)
input_ids = torch.LongTensor(input_ids)
segment_ids = torch.LongTensor(segment_ids)
masked_tokens = torch.LongTensor(masked_tokens)
masked_pos = torch.LongTensor(masked_pos)
isNext = torch.LongTensor(isNext)

# Define the dataset and data loader
class MyDataSet(Data.Dataset):
    def __init__(self,input_ids,segment_ids,masked_tokens,masked_pos,isNext):
        super(MyDataSet,self).__init__()
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
        self.masked_pos = masked_pos
        self.isNext = isNext

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self,idx):
        return self.input_ids[idx],self.segment_ids[idx],self.masked_tokens[idx],self.masked_pos[idx],self.isNext[idx]

loader=Data.DataLoader(MyDataSet(input_ids,segment_ids,masked_tokens,masked_pos,isNext),batch_size=BATCH_SIZE,shuffle=True)

# Padding mask used by the attention layers
def get_attn_pad_mask(seq_q,seq_k):
    '''
    seq_q:[batch_size,seq_len]
    seq_k:[batch_size,seq_len]
    '''
    batch_size,seq_len=seq_q.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    return pad_attn_mask.expand(batch_size,seq_len,seq_len)

# Activation function
def gelu(x):
    """
      Implementation of the gelu activation function.
      For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
      0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
      Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Embedding layer: token + position + segment embeddings
class Embedding(nn.Module):
    def __init__(self,vocab_size,max_len,num_segments,embed_dim):
        super(Embedding,self).__init__()
        self.input_embed = nn.Embedding(vocab_size,embed_dim)
        self.pos_embed = nn.Embedding(max_len,embed_dim)
        self.seg_embed = nn.Embedding(num_segments,embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self,input,segment):
        '''
        input:[batch_size,max_len]
        segment:[batch_size,max_len]
        '''
        #input_embed:[batch_size,max_len,embed_dim]
        input_embed = self.input_embed(input)

        seq_len = input.size(1)
        pos_embed = np.arange(seq_len)
        pos_embed = torch.LongTensor(pos_embed)
        #[batch_size,max_len]
        pos_embed = pos_embed.unsqueeze(0).expand_as(input)
        #[batch_size,max_len,embed_dim]
        pos_embed = self.pos_embed(pos_embed)

        seg_embed = self.seg_embed(segment)

        input_all_embed = input_embed + pos_embed + seg_embed
        return self.norm(input_all_embed)

# Scaled dot-product attention
def ScaledDotProductAttention(Q,K,V,attn_mask):
    '''
    Q:[batch_size,num_head,seq_len,query_dim]
    V:[batch_size,num_head,seq_len,value_dim]
    K:[batch_size,num_head,seq_len,key_dim]   其中query_dim=key_dim=value_dim
    attn_mask:[batch_size,num_head,seq_len,seq_len]
    '''
    scores = torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(Q.size(3))  #[batch_size,num_head,seq_len,seq_len]
    scores.masked_fill_(attn_mask,-1e9) #[batch_size,num_head,seq_len,seq_len]
    attn = nn.Softmax(-1)(scores)   #[batch_size,num_head,seq_len,seq_len]
    context = torch.matmul(attn,V)  #[batch_size,num_head,seq_len,value_dim]

    return context

# MultiHeadAttention层
class MultiHeadAttention(nn.Module):
    def __init__(self,embed_dim,num_head,query_dim,value_dim):
        super(MultiHeadAttention,self).__init__()
        self.W_Q = nn.Linear(embed_dim,num_head * query_dim)
        self.W_K = nn.Linear(embed_dim,num_head * query_dim)
        self.W_V = nn.Linear(embed_dim,num_head * value_dim)
        self.fc = nn.Linear(num_head * value_dim,embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.num_head = num_head
        self.query_dim = query_dim
        self.value_dim = value_dim

    def forward(self,input_q,input_k,input_v,attn_mask):
        "input_q:[batch_size,seq_len,embed_dim]"
        batch_size,residual = input_q.size(0),input_q

        #Q,K:[batch_size,num_head,seq_len,query_dim]
        #V:[batch_size,num_head,seq_len,value_dim]
        Q = self.W_Q(input_q).contiguous().view(batch_size,-1,self.num_head,self.query_dim).transpose(1,2) 
        K = self.W_K(input_k).contiguous().view(batch_size,-1,self.num_head,self.query_dim).transpose(1,2)
        V = self.W_V(input_v).contiguous().view(batch_size,-1,self.num_head,self.value_dim).transpose(1,2)

        #[batch_size,num_head,seq_len,seq_len]
        attn_mask = attn_mask.unsqueeze(1).repeat(1,self.num_head,1,1) 
        #[batch_size,seq_len,num_head * value_dim]
        outputs = ScaledDotProductAttention(Q,K,V,attn_mask).transpose(1,2).contiguous().view(batch_size,-1,self.num_head * self.value_dim)
        #[batch_size,seq_len,embed_dim]
        outputs = self.fc(outputs)

        return self.norm(outputs + residual)

# Position-wise feed-forward network
class FeedForward(nn.Module):
    def __init__(self,embed_dim,ff_dim):
        super(FeedForward,self).__init__()
        self.fc1 = nn.Linear(embed_dim,ff_dim)
        self.fc2 = nn.Linear(ff_dim,embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self,self_attn_output):
        'self_attn_output:[batch_size,seq_len,embed_dim]'
        residual = self_attn_output
        ff_output = self.fc2(gelu(self.fc1(self_attn_output)))
        return self.norm(residual + ff_output)

# A single Encoder layer
class EncoderLayer(nn.Module):
    def __init__(self,multiheadattention,feedforward):
        super(EncoderLayer,self).__init__()
        self.multiheadattention = multiheadattention
        self.feedforward = feedforward

    def forward(self,input_all_embed,attn_mask):
        '''
        input_all_embed:[batch_size,seq_len,embed_dim]
        attn_mask:[batch_size,seq_len,seq_len]
        '''
        attn_output = self.multiheadattention(input_all_embed,input_all_embed,input_all_embed,attn_mask)
        output = self.feedforward(attn_output)
        return output  

# BERT
class BERT(nn.Module):
    def __init__(self,embedding,encoderlayer,num_layer,embed_dim,vocab_size):
        super(BERT,self).__init__()
        self.embed = embedding
        self.layers = nn.ModuleList([copy.deepcopy(encoderlayer) for _ in range(num_layer)])  # independent copies, one per layer
        self.fc = nn.Sequential(
            nn.Linear(embed_dim,embed_dim),
            nn.Dropout(0.5),
            nn.Tanh()
        )
        self.classifier = nn.Linear(embed_dim, 2)

        self.mask_fc1 = nn.Linear(embed_dim, embed_dim)
        self.activ2 = gelu
        # the output projection could be tied to the token embedding matrix, as in the original BERT; here it is kept as a separate layer
        self.mask_fc2 = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self,input_ids,segment_ids,masked_pos):
        '''
        input_ids:[batch_size,seq_len]
        segment_ids:[batch_size,seq_len]
        '''
        input_all_embed = self.embed(input_ids,segment_ids)
        #[batch_size,seq_len,seq_len]
        attn_mask = get_attn_pad_mask(input_ids,input_ids)
        output = input_all_embed
        for layer in self.layers:
            output = layer(output, attn_mask)  # feed each layer's output into the next layer

        #[batch_size,embed_dim]
        class_input = self.fc(output[:,0])
        #[batch_size,2]
        class_output = self.classifier(class_input)

        masked_pos = masked_pos[:,:,None].expand(-1,-1,output.size(2))
        #[batch_size,max_pred,embed_size]
        output_mask = torch.gather(output,1,masked_pos)
        #[batch_size,max_pred,vocab_size]
        output_mask = self.mask_fc2(self.activ2(self.mask_fc1(output_mask)))

        return class_output,output_mask

# Instantiate the model, loss, and optimizer
feedforward = FeedForward(EMBED_DIM,FF_DIM)
multiheadattention = MultiHeadAttention(EMBED_DIM,NUM_HEADS,KEY_DIM,VALUE_DIM)
embedding = Embedding(VOCAB_SIZE,MAX_LEN,NUM_SEGMENTS,EMBED_DIM)
encoderlayer = EncoderLayer(multiheadattention,feedforward)
model = BERT(embedding,encoderlayer,NUM_LAYERS,EMBED_DIM,VOCAB_SIZE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters(),lr=1e-2)

# Training
for epoch in range(100):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
        class_output, mask_output = model(input_ids, segment_ids, masked_pos)
        loss_lm = criterion(mask_output.view(-1, VOCAB_SIZE), masked_tokens.view(-1))  # masked LM loss
        loss_lm = (loss_lm.float()).mean()
        loss_clsf = criterion(class_output, isNext)  # next-sentence classification loss
        loss = loss_lm + loss_clsf
        if (epoch + 1) % 10 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Testing: predict the masked tokens and the IsNext label for one sample
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[5]
print(TEXT)
print('================================')
print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])
class_output, mask_output = model(torch.LongTensor([input_ids]), \
                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))

mask_output = mask_output.data.max(2)[1][0].data.numpy()
print(mask_output)
print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ',[pos for pos in mask_output if pos != 0])

class_output = class_output.data.max(1)[1][0].data.numpy()
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if class_output else False)

The code above uses the gelu activation function, which the BERT paper adopts (it was originally proposed in https://arxiv.org/abs/1606.08415); for the exact formula see the article GELU 激活函数 (on the GELU activation function).
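
For reference, here is a small standalone comparison (not part of the model) between the exact erf form implemented above and the tanh approximation quoted in its docstring; the two agree very closely:

import math
import torch

def gelu_exact(x):
    # exact form used above: x * Phi(x), with Phi the standard normal CDF
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # tanh approximation quoted in the gelu docstring (used by OpenAI GPT)
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3, 3, 601)
print(torch.max(torch.abs(gelu_exact(x) - gelu_tanh(x))))  # a very small difference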

The code also uses the gather() function; for an explanation see the article gather函数的理解 (understanding the gather function).
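
As a small self-contained illustration of what gather() is doing in BERT.forward above (the tensor values and shapes here are made up just for the demo):

import torch

output = torch.arange(24, dtype=torch.float).view(1, 4, 6)  # [batch=1, seq_len=4, embed_dim=6]
masked_pos = torch.tensor([[2, 0]])                         # [batch=1, max_pred=2]

# expand the positions along the embedding dimension, then gather along dim=1 (the sequence axis)
index = masked_pos[:, :, None].expand(-1, -1, output.size(2))
picked = torch.gather(output, 1, index)
print(picked.shape)  # torch.Size([1, 2, 6]) -> the hidden states at positions 2 and 0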