Bilibili video walkthrough
This article explains how to reproduce BERT with PyTorch. Please spend about ten minutes first reading my article BERT 详解(附带 ELMo、GPT 介绍) before coming back here; it will make this post far easier to follow and the effort much more rewarding.
Preparing the Dataset
Instead of a large corpus, I typed in a short dialogue between two people. The goal is to keep the code easy to read, so readers can focus on the model implementation itself.
# Import the required libraries
import re
import copy
import math
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
# Text used for training
TEXT = (
'Hello, how are you? I am Romeo.\n' # R
'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
'Nice meet you too. How are you today?\n' # R
'Great. My baseball team won the competition.\n' # J
'Oh Congratulations, Juliet\n' # R
'Thank you Romeo\n' # J
'Where are you going today?\n' # R
'I am going shopping. What about you?\n' # J
'I am going to visit my grandmother. she is not very well' # R
)
# BERT hyperparameters
MAX_LEN = 30
BATCH_SIZE = 6
MAX_PRED = 5 # max tokens of prediction
NUM_LAYERS = 6
NUM_HEADS = 12
EMBED_DIM = 768
FF_DIM = 768*4 # 4*EMBED_DIM, FeedForward dimension
KEY_DIM = VALUE_DIM = 64 # dimension of K(=Q), V
NUM_SEGMENTS = 2
MAX_LEN: every sentence in a batch consists of 30 tokens, with [PAD] filling up the shorter ones (my implementation is rather crude here: every sentence in every batch is simply fixed at 30 tokens).
MAX_PRED: the maximum number of tokens to predict, i.e. BERT's cloze (masked language model) task.
NUM_LAYERS: the number of Encoder Layers.
EMBED_DIM: the dimension of the Token Embeddings, Segment Embeddings and Position Embeddings.
FF_DIM: the dimension of the fully connected (feed-forward) layer inside each Encoder Layer.
NUM_SEGMENTS: how many sentences the Encoder input is made up of.
Data Preprocessing
In the data-preprocessing step, we need to randomly mask or replace (collectively called "mask" below) 15% of the tokens in a sentence, and we also need to concatenate two arbitrary sentences.
def make_data(text,max_len,batch_size,max_pred):
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
word_list = list(set(" ".join(sentences).split()))
word2idx={'[PAD]':0,'[CLS]':1,'[SEP]':2,'[MASK]':3}
for i,w in enumerate(word_list):
word2idx[w] = i + 4
idx2word = {i:w for i,w in enumerate(word2idx)}
vocab_size = len(word2idx)
token_list = []
for sentence in sentences:
sentence=[word2idx[w] for w in sentence.split()]
token_list.append(sentence)
batch = []
positive = negative = 0
while positive != batch_size/2 or negative != batch_size/2:
tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
# MASK LM
n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15))) # 15 % of tokens in one sentence
cand_maked_pos = [i for i, token in enumerate(input_ids)
if token != word2idx['[CLS]'] and token != word2idx['[SEP]']] # candidate masked position
shuffle(cand_maked_pos)
masked_tokens, masked_pos = [], []
for pos in cand_maked_pos[:n_pred]:
masked_pos.append(pos)
masked_tokens.append(input_ids[pos])
            dice = random()
            if dice < 0.8:  # 80%: replace with [MASK]
                input_ids[pos] = word2idx['[MASK]']
            elif dice > 0.9:  # 10%: replace with a random word
                index = randint(0, vocab_size - 1)  # random index in the vocabulary
                while index < 4:  # must not be '[PAD]', '[CLS]', '[SEP]' or '[MASK]'
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index
            # remaining 10%: keep the original token
# Zero Paddings
n_pad = max_len - len(input_ids)
input_ids.extend([0] * n_pad)
segment_ids.extend([0] * n_pad)
# Zero Padding (100% - 15%) tokens
if max_pred > n_pred:
n_pad = max_pred - n_pred
masked_tokens.extend([0] * n_pad)
masked_pos.extend([0] * n_pad)
if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
positive += 1
elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
negative += 1
return batch,vocab_size,idx2word
# Preprocessing finished
In the code above, the variable positive counts the sentence pairs that are consecutive, and negative counts the pairs that are not. Within one batch we keep these two kinds of samples at a 1:1 ratio. Whether the two randomly chosen sentences are consecutive is decided simply by checking tokens_a_index + 1 == tokens_b_index.
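As a quick illustration, the IsNext label comes entirely from this one comparison (the indices below are made up, not taken from the real data):
# Toy illustration of the IsNext labelling rule (made-up indices)
tokens_a_index, tokens_b_index = 2, 3
print(tokens_a_index + 1 == tokens_b_index)   # True  -> counted as a positive (IsNext) sample

tokens_a_index, tokens_b_index = 4, 1
print(tokens_a_index + 1 == tokens_b_index)   # False -> counted as a negative (NotNext) sample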
Next, some tokens are randomly masked. The variable n_pred is the number of tokens that will be masked, and cand_maked_pos lists the candidate positions that may be masked ([SEP] and [CLS] are excluded, since masking them would be meaningless). The candidates are then shuffled, and for each selected position the value of random() decides whether the token is replaced with [MASK], replaced with another random token, or kept unchanged.
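To make the 80% / 10% / 10% split concrete, here is a minimal standalone sketch of the per-position decision; mask_decision is a hypothetical helper used only for illustration and is not part of the training code:
from random import random, randint

# Simplified restatement of the masking decision above (illustration only)
def mask_decision(original_id, mask_id, vocab_size, num_special=4):
    """80%: return [MASK], 10%: return a random non-special token, 10%: keep the original."""
    dice = random()
    if dice < 0.8:
        return mask_id                               # 80%: replace with [MASK]
    elif dice > 0.9:
        return randint(num_special, vocab_size - 1)  # 10%: replace with a random word
    return original_id                               # 10%: keep the original token

print(mask_decision(original_id=17, mask_id=3, vocab_size=40))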
Two zero-padding steps follow. The first pads each sequence so that every sequence in a batch has the same length. The second pads the masked positions: sentences of different lengths produce different numbers of masked tokens, and we must guarantee that every sample in a batch records the same number of masks, so the lists are padded at the end with meaningless placeholders such as 0.
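For example (purely illustrative numbers), if max_pred is 5 but only three tokens were masked in a given sequence, the two lists are padded like this:
# Illustrative numbers only: pad the masked-LM targets up to max_pred = 5
max_pred = 5
masked_tokens = [12, 7, 25]   # token ids of the 3 masked positions
masked_pos = [4, 9, 11]       # their positions inside input_ids

n_pad = max_pred - len(masked_tokens)
masked_tokens.extend([0] * n_pad)
masked_pos.extend([0] * n_pad)

print(masked_tokens)  # [12, 7, 25, 0, 0]
print(masked_pos)     # [4, 9, 11, 0, 0]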
That completes the data preprocessing.
Building the Model
The model structure is essentially the Transformer Encoder, so I will not go over it again here; see my article Transformer 的 PyTorch 实现 and the accompanying Bilibili video walkthrough.
batch,VOCAB_SIZE,idx2word=make_data(TEXT,MAX_LEN,BATCH_SIZE,MAX_PRED)
input_ids, segment_ids, masked_tokens, masked_pos, isNext=zip(*batch)
input_ids = torch.LongTensor(input_ids)
segment_ids = torch.LongTensor(segment_ids)
masked_tokens = torch.LongTensor(masked_tokens)
masked_pos = torch.LongTensor(masked_pos)
isNext = torch.LongTensor(isNext)
# Define the dataset and data loader
class MyDataSet(Data.Dataset):
def __init__(self,input_ids,segment_ids,masked_tokens,masked_pos,isNext):
super(MyDataSet,self).__init__()
self.input_ids = input_ids
self.segment_ids = segment_ids
self.masked_tokens = masked_tokens
self.masked_pos = masked_pos
self.isNext = isNext
def __len__(self):
return len(self.input_ids)
def __getitem__(self,idx):
return self.input_ids[idx],self.segment_ids[idx],self.masked_tokens[idx],self.masked_pos[idx],self.isNext[idx]
loader=Data.DataLoader(MyDataSet(input_ids,segment_ids,masked_tokens,masked_pos,isNext),batch_size=BATCH_SIZE,shuffle=True)
# Define the attention padding-mask function
def get_attn_pad_mask(seq_q,seq_k):
'''
seq_q:[batch_size,seq_len]
seq_k:[batch_size,seq_len]
'''
batch_size,seq_len=seq_q.size()
    pad_attn_mask = seq_k.eq(0).unsqueeze(1)  # [batch_size, 1, seq_len], True where the token is [PAD]
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len]
# Define the GELU activation function
def gelu(x):
"""
Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
Also see https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
# Embedding layer: token + position + segment embeddings
class Embedding(nn.Module):
def __init__(self,vocab_size,max_len,num_segments,embed_dim):
super(Embedding,self).__init__()
self.input_embed = nn.Embedding(vocab_size,embed_dim)
self.pos_embed = nn.Embedding(max_len,embed_dim)
self.seg_embed = nn.Embedding(num_segments,embed_dim)
self.norm = nn.LayerNorm(embed_dim)
def forward(self,input,segment):
'''
input:[batch_size,max_len]
segment:[batch_size,max_len]
'''
#input_embed:[batch_size,max_len,embed_dim]
input_embed = self.input_embed(input)
seq_len = input.size(1)
        pos_embed = torch.arange(seq_len, dtype=torch.long, device=input.device)
#[batch_size,max_len]
pos_embed = pos_embed.unsqueeze(0).expand_as(input)
#[batch_size,max_len,embed_dim]
pos_embed = self.pos_embed(pos_embed)
seg_embed = self.seg_embed(segment)
input_all_embed = input_embed + pos_embed + seg_embed
return self.norm(input_all_embed)
# Scaled dot-product attention
def ScaledDotProductAttention(Q,K,V,attn_mask):
'''
Q:[batch_size,num_head,seq_len,query_dim]
V:[batch_size,num_head,seq_len,value_dim]
    K:[batch_size,num_head,seq_len,key_dim]  (here query_dim = key_dim = value_dim)
attn_mask:[batch_size,num_head,seq_len,seq_len]
'''
scores = torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(Q.size(3)) #[batch_size,num_head,seq_len,seq_len]
scores.masked_fill_(attn_mask,-1e9) #[batch_size,num_head,seq_len,seq_len]
attn = nn.Softmax(-1)(scores) #[batch_size,num_head,seq_len,seq_len]
context = torch.matmul(attn,V) #[batch_size,num_head,seq_len,value_dim]
return context
# Multi-head attention layer
class MultiHeadAttention(nn.Module):
def __init__(self,embed_dim,num_head,query_dim,value_dim):
super(MultiHeadAttention,self).__init__()
self.W_Q = nn.Linear(embed_dim,num_head * query_dim)
self.W_K = nn.Linear(embed_dim,num_head * query_dim)
self.W_V = nn.Linear(embed_dim,num_head * value_dim)
self.fc = nn.Linear(num_head * value_dim,embed_dim)
self.norm = nn.LayerNorm(embed_dim)
self.num_head = num_head
self.query_dim = query_dim
self.value_dim = value_dim
def forward(self,input_q,input_k,input_v,attn_mask):
"input_q:[batch_size,seq_len,embed_dim]"
batch_size,residual = input_q.size(0),input_q
#Q,K:[batch_size,num_head,seq_len,query_dim]
#V:[batch_size,num_head,seq_len,value_dim]
Q = self.W_Q(input_q).contiguous().view(batch_size,-1,self.num_head,self.query_dim).transpose(1,2)
K = self.W_K(input_k).contiguous().view(batch_size,-1,self.num_head,self.query_dim).transpose(1,2)
V = self.W_V(input_v).contiguous().view(batch_size,-1,self.num_head,self.value_dim).transpose(1,2)
#[batch_size,num_head,seq_len,seq_len]
attn_mask = attn_mask.unsqueeze(1).repeat(1,self.num_head,1,1)
#[batch_size,seq_len,num_head * value_dim]
outputs = ScaledDotProductAttention(Q,K,V,attn_mask).transpose(1,2).contiguous().view(batch_size,-1,self.num_head * self.value_dim)
#[batch_size,seq_len,embed_dim]
outputs = self.fc(outputs)
return self.norm(outputs + residual)
# Position-wise feed-forward network
class FeedForward(nn.Module):
def __init__(self,embed_dim,ff_dim):
super(FeedForward,self).__init__()
self.fc1 = nn.Linear(embed_dim,ff_dim)
self.fc2 = nn.Linear(ff_dim,embed_dim)
self.norm = nn.LayerNorm(embed_dim)
def forward(self,self_attn_output):
'self_attn_output:[batch_size,seq_len,embed_dim]'
residual = self_attn_output
ff_output = self.fc2(gelu(self.fc1(self_attn_output)))
return self.norm(residual + ff_output)
# A single Encoder layer
class EncoderLayer(nn.Module):
def __init__(self,multiheadattention,feedforward):
super(EncoderLayer,self).__init__()
self.multiheadattention = multiheadattention
self.feedforward = feedforward
def forward(self,input_all_embed,attn_mask):
'''
input_all_embed:[batch_size,seq_len,embed_dim]
attn_mask:[batch_size,seq_len,seq_len]
'''
attn_output = self.multiheadattention(input_all_embed,input_all_embed,input_all_embed,attn_mask)
output = self.feedforward(attn_output)
return output
# BERT
class BERT(nn.Module):
def __init__(self,embedding,encoderlayer,num_layer,embed_dim,vocab_size):
super(BERT,self).__init__()
self.embed = embedding
        # deep-copy the layer so that each of the num_layer Encoder layers has its own parameters
        self.layers = nn.ModuleList([copy.deepcopy(encoderlayer) for _ in range(num_layer)])
self.fc = nn.Sequential(
nn.Linear(embed_dim,embed_dim),
nn.Dropout(0.5),
nn.Tanh()
)
self.classifier = nn.Linear(embed_dim, 2)
self.mask_fc1 = nn.Linear(embed_dim, embed_dim)
self.activ2 = gelu
        # output projection for the masked LM task (weight tying with the token embedding is not implemented here)
self.mask_fc2 = nn.Linear(embed_dim, vocab_size, bias=False)
def forward(self,input_ids,segment_ids,masked_pos):
'''
input_ids:[batch_size,seq_len]
segment_ids:[batch_size,seq_len]
'''
input_all_embed = self.embed(input_ids,segment_ids)
#[batch_size,seq_len,seq_len]
attn_mask = get_attn_pad_mask(input_ids,input_ids)
        output = input_all_embed
        for layer in self.layers:
            # output: [batch_size, seq_len, embed_dim]
            output = layer(output, attn_mask)
#[batch_size,embed_dim]
class_input = self.fc(output[:,0])
#[batch_size,2]
class_output = self.classifier(class_input)
masked_pos = masked_pos[:,:,None].expand(-1,-1,output.size(2))
#[batch_size,max_pred,embed_size]
output_mask = torch.gather(output,1,masked_pos)
#[batch_size,max_pred,vocab_size]
output_mask = self.mask_fc2(self.activ2(self.mask_fc1(output_mask)))
return class_output,output_mask
# Build the model
feedforward = FeedForward(EMBED_DIM,FF_DIM)
multiheadattention = MultiHeadAttention(EMBED_DIM,NUM_HEADS,KEY_DIM,VALUE_DIM)
embedding = Embedding(VOCAB_SIZE,MAX_LEN,NUM_SEGMENTS,EMBED_DIM)
encoderlayer = EncoderLayer(multiheadattention,feedforward)
model = BERT(embedding,encoderlayer,NUM_LAYERS,EMBED_DIM,VOCAB_SIZE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters(),lr=1e-2)
# Train the model
for epoch in range(100):
for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
class_output, mask_output = model(input_ids, segment_ids, masked_pos)
        loss_lm = criterion(mask_output.view(-1, VOCAB_SIZE), masked_tokens.view(-1)) # for masked LM
loss_lm = (loss_lm.float()).mean()
loss_clsf = criterion(class_output, isNext) # for sentence classification
loss = loss_lm + loss_clsf
if (epoch + 1) % 10 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Test the model
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[5]
print(TEXT)
print('================================')
print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])
class_output, mask_output = model(torch.LongTensor([input_ids]), \
torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
mask_output = mask_output.data.max(2)[1][0].data.numpy()
print(mask_output)
print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ',[pos for pos in mask_output if pos != 0])
class_output = class_output.data.max(1)[1][0].data.numpy()
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if class_output else False)
This code uses the GELU activation function, which BERT adopts (GELU itself was proposed by Hendrycks & Gimpel; see the arXiv link in the docstring above). The exact formula is covered in the article GELU 激活函数.
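As a quick sanity check (a standalone sketch using only standard PyTorch functions), the erf-based gelu above matches torch.nn.functional.gelu, while GPT's tanh approximation is merely close:
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=7)

# Exact GELU, same formula as the gelu() defined above
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# PyTorch's built-in exact GELU and GPT's tanh approximation
builtin = F.gelu(x)
tanh_approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

print(torch.allclose(exact, builtin, atol=1e-6))  # True
print((exact - tanh_approx).abs().max())          # small but non-zero difference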
The code also uses the gather() function; see the article gather函数的理解 for a detailed explanation.
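Here is a small self-contained illustration (toy shapes only) of how torch.gather pulls out the hidden vectors at the masked positions in BERT.forward:
import torch

batch_size, seq_len, embed_dim, max_pred = 2, 5, 4, 3

# Toy "encoder output" and masked positions, shapes matching BERT.forward above
output = torch.arange(batch_size * seq_len * embed_dim, dtype=torch.float).view(batch_size, seq_len, embed_dim)
masked_pos = torch.tensor([[1, 3, 0],   # positions to predict in sample 0
                           [2, 4, 0]])  # positions to predict in sample 1

# Expand to [batch_size, max_pred, embed_dim] so gather can index along dim=1
index = masked_pos[:, :, None].expand(-1, -1, embed_dim)
output_mask = torch.gather(output, 1, index)

print(output_mask.shape)                             # torch.Size([2, 3, 4])
print(torch.equal(output_mask[0, 0], output[0, 1]))  # True: the hidden vector at position 1 of sample 0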