Bilibili video walkthrough
This article walks through how to reproduce the Transformer with PyTorch and use it for a simple machine translation task. Please spend about fifteen minutes reading my article Transformer 详解 (a detailed explanation of the Transformer) first and then come back here; it will make this post far easier to digest.
Data Preprocessing
I did not use any large dataset here. Instead, I typed in two German→English sentence pairs by hand, and the index of every word is hard-coded as well. The point is to keep the code easy to read so that readers can focus on the model implementation itself.
# import the required libraries
import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
import matplotlib.pyplot as plt
import math
import torch.optim as optim
# S: Symbol that shows starting of decoding input
# E: Symbol that shows ending of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is shorter than time steps
sentences = [
# enc_input dec_input dec_output
['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]
# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
SRC_VOCAB_SIZE = len(src_vocab)
tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
TGT_VOCAB_SIZE = len(tgt_vocab)
SRC_LEN = 5 # enc_input max sequence length
TGT_LEN = 6 # dec_input(=dec_output) max sequence length
def make_data(sentences):
enc_inputs, dec_inputs, dec_outputs = [], [], []
for i in range(len(sentences)):
enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]
#enc_inputs:[2,5]
#dec_inputs:[2,6]
#dec_output:[2,6]
enc_inputs.extend(enc_input)
dec_inputs.extend(dec_input)
dec_outputs.extend(dec_output)
return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)
class MyDataSet(Data.Dataset):
def __init__(self, enc_inputs, dec_inputs, dec_outputs):
super(MyDataSet, self).__init__()
self.enc_inputs = enc_inputs
self.dec_inputs = dec_inputs
self.dec_outputs = dec_outputs
def __len__(self):
return self.enc_inputs.shape[0]
def __getitem__(self, idx):
return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)
Model Parameters
The variables below stand for, in order:
- the dimension of the word embeddings & positional embeddings; the two are equal, so a single variable is enough
- the number of hidden units in the FeedForward layer
- the dimension of the Q, K and V vectors; Q and K must have the same dimension, V's dimension is unconstrained, but for convenience I set them all to 64
- the number of Encoder and Decoder layers
- the number of heads in Multi-Head Attention
# Transformer Parameters
EMBED_DIM = 512          # Embedding Size
FF_DIM = 2048            # FeedForward dimension
KEY_DIM = VALUE_DIM = 64 # dimension of K(=Q), V
NUM_LAYER = 6            # number of Encoder and Decoder layers
NUM_HEAD = 8             # number of heads in Multi-Head Attention
Everything above is fairly simple. From here on we get into the model itself, which is more involved, so I will split the model into the following parts and explain them one by one:
- Positional Encoding
- Pad Mask (sentences that are too short are padded, so the pad tokens must be masked)
- Subsequence Mask (the Decoder input must not see words at future time steps, so they need to be masked)
- ScaledDotProductAttention(计算 context vector)
- Multi-Head Attention
- FeedForward Layer
- Encoder Layer
- Encoder
- Decoder Layer
- Decoder
- Transformer
About the comments in the code: whenever a value is definitely src_len or tgt_len, I spell it out. But some functions or classes may be called by both the Encoder and the Decoder, so it is impossible to say whether the length is src_len or tgt_len; in those uncertain cases I write seq_len.
Positional Encoding
def get_positional_encoding(max_seq_len, embed_dim):
    # build the positional encoding table
    # embed_dim: dimension of the word embedding
    # max_seq_len: maximum sequence length
positional_encoding = np.array([
[pos / np.power(10000, 2 * (i//2) / embed_dim) for i in range(embed_dim)]
if pos != 0 else np.zeros(embed_dim) for pos in range(max_seq_len)])
    positional_encoding[1:, 0::2] = np.sin(positional_encoding[1:, 0::2]) # dim 2i (even)
    positional_encoding[1:, 1::2] = np.cos(positional_encoding[1:, 1::2]) # dim 2i+1 (odd)
return torch.FloatTensor(positional_encoding)
This code is not complicated. The two arguments are the maximum sequence length and the dimension of the positional encoding. The returned tensor has shape [max_seq_len, embed_dim], exactly matching the word embeddings.
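To see how this table is used later, here is a minimal sketch (illustrative only, run in the context of the code above) that checks the shapes and wraps the table in a frozen nn.Embedding, exactly as the Encoder and Decoder below do:
pe = get_positional_encoding(SRC_LEN, EMBED_DIM)   # [5, 512]
print(pe.shape)                                    # torch.Size([5, 512])
# the table is wrapped in a frozen embedding and looked up by position index
pos_embed = nn.Embedding.from_pretrained(pe, freeze=True)
pos = torch.arange(SRC_LEN).unsqueeze(0)           # [1, 5], positions 0..4
print(pos_embed(pos).shape)                        # torch.Size([1, 5, 512])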
Pad Mask
def get_attn_pad_mask(seq_q, seq_k):
'''
seq_q: [batch_size, seq_len]
seq_k: [batch_size, seq_len]
seq_len could be src_len or it could be tgt_len
seq_len in seq_q and seq_len in seq_k maybe not equal
'''
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
# eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # [batch_size, 1, len_k], True means masked
return pad_attn_mask.expand(batch_size, len_q, len_k) # [batch_size, len_q, len_k]
Since the mask operation is needed in both the Encoder and the Decoder, the seq_len in this function's arguments cannot be pinned down. If the function is called from the Encoder, seq_len equals src_len; if it is called from the Decoder, seq_len may equal src_len or tgt_len (because the Decoder masks twice).
The key line of this function is seq_k.data.eq(0), which returns a tensor of the same size as seq_k that contains only True and False: wherever seq_k equals 0 the result is True, everywhere else it is False. For example, with seq_data = [1, 2, 3, 4, 0], seq_data.data.eq(0) returns [False, False, False, False, True]. The rest of the code just expands the dimensions; I strongly recommend printing the result to see what the returned tensor actually looks like.
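For the curious, here is a small sketch (illustrative only, using the toy enc_inputs built earlier; the expected output is shown as comments):
print(enc_inputs)
# tensor([[1, 2, 3, 4, 0],
#         [1, 2, 3, 5, 0]])
mask = get_attn_pad_mask(enc_inputs, enc_inputs)
print(mask.shape)  # torch.Size([2, 5, 5])
print(mask[0])     # every row is [False, False, False, False,  True]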
Subsequence Mask
def get_attn_subsequence_mask(seq):
"seq:[batch_size,dec_seq_len]"
attn_shape=[seq.shape[0],seq.shape[1],seq.shape[1]]
subsequence_mask=np.triu(np.ones(attn_shape),1)
subsequence_mask=torch.from_numpy(subsequence_mask).byte()
return subsequence_mask
The Subsequence Mask is only used by the Decoder; its job is to hide the information of words at future time steps. We first build an all-ones square matrix with np.ones() and then keep its upper triangle with np.triu(); a quick illustration of np.triu() follows.
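This is what np.triu(..., k=1) does to an all-ones matrix (illustrative only, output shown as comments):
print(np.triu(np.ones((4, 4)), k=1))
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]
# row i is the mask for decoding step i: the 1s mark the future positions that must be hidden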
ScaledDotProductAttention
def ScaledDotProductAttention(Q,K,V,attn_mask):
'''
Q:[batch_size,head,len_q,d_k]
K:[batch_size,head,len_k,d_k]
    V:[batch_size,head,len_v(=len_k),d_v]
attn_mask:[batch_size,head,len_q,len_k]
'''
scores=torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(Q.size(3))
scores.masked_fill_(attn_mask,-1e9)
attn=torch.nn.Softmax(-1)(scores)
context=torch.matmul(attn,V)
return context,attn
What happens here: scores are computed from Q and K, and then scores and V are combined to give every word its context vector.
The first step, multiplying Q by the transpose of K, needs no further explanation. The resulting scores cannot go through softmax right away; they must first be masked with attn_mask so that the positions we want to hide are suppressed (masked_fill_ writes -1e9 into every position where attn_mask is True). attn_mask is a tensor containing only True and False, and all four of its dimensions are guaranteed to match those of scores (otherwise the element-wise masking could not be applied).
Once the masking is done, softmax can be applied to scores, and the result is multiplied by V to obtain context.
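Here is a quick shape check with random tensors (illustrative only, reusing NUM_HEAD, KEY_DIM and VALUE_DIM defined above):
B, H, LQ, LK = 2, NUM_HEAD, 4, 5
Q = torch.randn(B, H, LQ, KEY_DIM)
K = torch.randn(B, H, LK, KEY_DIM)
V = torch.randn(B, H, LK, VALUE_DIM)
mask = torch.zeros(B, H, LQ, LK, dtype=torch.bool)  # nothing masked in this toy check
context, attn = ScaledDotProductAttention(Q, K, V, mask)
print(context.shape)           # torch.Size([2, 8, 4, 64])
print(attn.shape)              # torch.Size([2, 8, 4, 5])
print(attn[0, 0].sum(dim=-1))  # each row of attention weights sums to 1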
MultiHeadAttention
class MultiHeadAttention(nn.Module):
def __init__(self,embed_dim,num_head,query_dim,value_dim):
super(MultiHeadAttention,self).__init__()
self.head=num_head
self.query=query_dim
self.value=value_dim
self.W_Q=nn.Linear(embed_dim,num_head*query_dim,bias=True)
self.W_K=nn.Linear(embed_dim,num_head*query_dim,bias=True)
self.W_V=nn.Linear(embed_dim,num_head*value_dim,bias=True)
self.fc=nn.Linear(num_head*value_dim,embed_dim)
def forward(self,inputs_query,inputs_key,inputs_value,attn_mask):
'''
inputs_query:[batch_size,len_q,embed_dim]
inputs_key:[batch_size,len_k,embed_dim]
inputs_value:[batch_size,len_v,embed_dim]
attn_mask:[batch_size,len_q,len_k]
'''
residual,batch_size=inputs_query,inputs_query.size(0)
#Q:[batch_size,num_head,len_q,query_dim]
Q=self.W_Q(inputs_query).view(batch_size,-1,self.head,self.query).transpose(1,2)
#K:[batch_size,num_head,len_k,query_dim]
K=self.W_K(inputs_key).view(batch_size,-1,self.head,self.query).transpose(1,2)
#V:[batch_size,num_head,len_v,value_dim]
V=self.W_V(inputs_value).view(batch_size,-1,self.head,self.value).transpose(1,2)
#attn_mask:[batch_size,num_head,len_q,len_k]
attn_mask=attn_mask.unsqueeze(1).repeat(1,self.head,1,1)
#context:[batch_size,num_head,len_q,value_dim]
context,attn=ScaledDotProductAttention(Q,K,V,attn_mask)
context=context.transpose(1,2).contiguous().view(batch_size,-1,self.head*self.value)
output=self.fc(context)
return nn.LayerNorm(output.size(2))(output+residual),attn
The complete code calls MultiHeadAttention() in exactly three places. The Encoder Layer calls it once, passing enc_inputs for all of inputs_query, inputs_key and inputs_value. The Decoder Layer calls it twice: the first time everything passed in is dec_inputs; the second time the three arguments are dec_outputs, enc_outputs and enc_outputs respectively.
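As a quick sanity check (illustrative only), the module can be run on random "embeddings" together with the pad mask from the toy batch:
mha = MultiHeadAttention(EMBED_DIM, NUM_HEAD, KEY_DIM, VALUE_DIM)
x = torch.randn(2, SRC_LEN, EMBED_DIM)            # stand-in for embedded encoder inputs
mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [2, 5, 5]
out, attn = mha(x, x, x, mask)                    # self-attention: query = key = value source
print(out.shape)   # torch.Size([2, 5, 512])
print(attn.shape)  # torch.Size([2, 8, 5, 5])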
FeedForward Layer
class FeedForward(nn.Module):
def __init__(self,embed_dim,feedforward_dim):
super(FeedForward,self).__init__()
self.fc=nn.Sequential(
nn.Linear(embed_dim,feedforward_dim,bias=True),
nn.ReLU(),
nn.Linear(feedforward_dim,embed_dim,bias=True)
)
def forward(self,inputs):
"inputs:[batch_size,len_q,embed_size]"
residual=inputs
outputs=self.fc(inputs)
return nn.LayerNorm(inputs.size(2))(residual + outputs)#size:[batch_size,len_q,embed_size]
This code is very simple: two linear transformations, a residual connection, and then a Layer Norm.
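One thing worth pointing out: both FeedForward and MultiHeadAttention construct nn.LayerNorm inside forward, so its affine parameters are re-initialized on every call rather than learned. That is fine for this toy task; a more conventional variant (my own sketch, not the author's code) registers the LayerNorm once in __init__:
class FeedForwardLN(nn.Module):
    """Variant of FeedForward with the LayerNorm registered once in __init__."""
    def __init__(self, embed_dim, feedforward_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, feedforward_dim, bias=True),
            nn.ReLU(),
            nn.Linear(feedforward_dim, embed_dim, bias=True),
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
    def forward(self, inputs):
        # inputs: [batch_size, len_q, embed_dim]
        return self.layer_norm(inputs + self.fc(inputs))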
Encoder Layer
class EncoderLayer(nn.Module):
def __init__(self,multiheadattention,feedforward):
super(EncoderLayer,self).__init__()
self.enc_self_attention=multiheadattention
self.feedforward=feedforward
def forward(self,enc_inputs,enc_attn_mask):
'''
enc_inputs:[batch_size,seq_len,embed_size]
enc_attn_mask:[batch_size,len_q,len_k]
'''
#outputs:[batch_size,seq_len,embed_size]
#attn:[batch_size,len_q,len_k]
outputs,attn=self.enc_self_attention(enc_inputs,enc_inputs,enc_inputs,enc_attn_mask)
outputs=self.feedforward(outputs)
return outputs,attn
Stitching the components above together gives a complete Encoder Layer.
Encoder
class Encoder(nn.Module):
def __init__(self,src_vocab_size,embed_dim,pos_embed,encoderlayer,num_layer):
super(Encoder,self).__init__()
self.src_embed=nn.Embedding(src_vocab_size,embed_dim)
self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
self.layer=nn.ModuleList([encoderlayer for _ in range(num_layer)])
def forward(self,enc_inputs):
"enc_inputs:[batch_size,seq_len]"
seq_len = enc_inputs.size(1)
pos = torch.LongTensor([i for i in range(seq_len)])
pos = pos.unsqueeze(0).expand_as(enc_inputs)
word_embed=self.src_embed(enc_inputs)
pos_embed=self.pos_embed(pos)
input_embed=word_embed+pos_embed
enc_attn_mask=get_attn_pad_mask(enc_inputs,enc_inputs)
enc_attn=[]
for layer in self.layer:
outputs,attn=layer(input_embed,enc_attn_mask)
input_embed = outputs
enc_attn.append(attn)
return outputs,enc_attn
nn.ModuleList() takes a list as its argument, and the list here holds num_layer Encoder Layers.
Because we made sure the input and output of an Encoder Layer have the same dimensions, we can simply loop over the layers, feeding the output of one Encoder Layer in as the input of the next.
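Note that because the very same encoderlayer object is put into the nn.ModuleList num_layer times (the Decoder below does the same, and the multiheadattention instance is shared too), all of the layers actually share one set of weights. If independent per-layer parameters are wanted, a small variant (my own sketch, not part of the original code) is to deep-copy the layer:
import copy

class IndependentEncoder(Encoder):
    """Same as Encoder, but every layer gets its own deep-copied parameters."""
    def __init__(self, src_vocab_size, embed_dim, pos_embed, encoderlayer, num_layer):
        super().__init__(src_vocab_size, embed_dim, pos_embed, encoderlayer, num_layer)
        self.layer = nn.ModuleList([copy.deepcopy(encoderlayer) for _ in range(num_layer)])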
Decoder Layer
class DecoderLayer(nn.Module):
def __init__(self,multiheadattention,feedforward):
super(DecoderLayer,self).__init__()
self.dec_self_attn=multiheadattention
self.enc_dec_self_attn=multiheadattention
self.feedforward=feedforward
def forward(self,dec_inputs,enc_outputs,dec_attn_mask,enc_dec_attn_mask):
'''
dec_inputs:[batch_size,tgt_len,embed_dim]
enc_outputs:[batch_size,src_len,embed_dim]
dec_attn_mask:[batch_size,tgt_len,tgt_len]
        enc_dec_attn_mask:[batch_size,tgt_len,src_len]
'''
#dec_outputs:[batch_size,tgt_len,embed_dim]
dec_outputs,dec_attn=self.dec_self_attn(dec_inputs,dec_inputs,dec_inputs,dec_attn_mask)
dec_outputs,enc_dec_attn=self.enc_dec_self_attn(dec_outputs,enc_outputs,enc_outputs,enc_dec_attn_mask)
dec_outputs=self.feedforward(dec_outputs)
return dec_outputs,dec_attn,enc_dec_attn
The Decoder Layer calls MultiHeadAttention twice. The first call computes self-attention over the Decoder input and produces dec_outputs. Then dec_outputs is used to generate Q while enc_outputs is used to generate K and V, and MultiHeadAttention is called a second time; this gives the context vectors between the Encoder outputs and the Decoder Layer. Finally, dec_outputs is passed through the FeedForward layer and returned.
Decoder
class Decoder(nn.Module):
def __init__(self,tgt_vocab_size,embed_dim,pos_embed,decoderlayer,num_layer):
super(Decoder,self).__init__()
self.tgt_embed=nn.Embedding(tgt_vocab_size,embed_dim)
self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
self.layers=nn.ModuleList([decoderlayer for _ in range(num_layer)])
def forward(self,dec_inputs,enc_inputs,enc_outputs):
"dec_inputs:[batch_size,tgt_len]"
seq_len = dec_inputs.size(1)
pos = torch.LongTensor([i for i in range(seq_len)])
pos = pos.unsqueeze(0).expand_as(dec_inputs)
tgt_embed=self.tgt_embed(dec_inputs)
pos_embed=self.pos_embed(pos)
input_embed=tgt_embed+pos_embed
dec_self_attn_pad_mask=get_attn_pad_mask(dec_inputs,dec_inputs)
dec_subsequence_attn_mask=get_attn_subsequence_mask(dec_inputs)
dec_self_attn_mask=dec_self_attn_pad_mask+dec_subsequence_attn_mask
dec_self_attn_mask=torch.gt(dec_self_attn_mask,0)
enc_dec_self_attn_mask=get_attn_pad_mask(dec_inputs,enc_inputs)
dec_attn_all=[]
enc_dec_attn_all=[]
for layer in self.layers:
outputs,dec_attn,enc_dec_attn=layer(input_embed,enc_outputs,dec_self_attn_mask,enc_dec_self_attn_mask)
input_embed = outputs
dec_attn_all.append(dec_attn)
enc_dec_attn_all.append(enc_dec_attn)
return outputs,dec_attn_all,enc_dec_attn_all
The Decoder not only has to mask out the "pad" tokens but also the information of future time steps, hence the following three lines of code. torch.gt(a, value) compares every element of a with value: positions whose value is greater become True, the rest become False.
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) # [batch_size, tgt_len, tgt_len]
dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs) # [batch_size, tgt_len, tgt_len]
dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0) # [batch_size, tgt_len, tgt_len]
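Printed on the toy dec_inputs (illustrative only; the output is shown as comments), the combined mask looks like this:
pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)  # [2, 6, 6], all False here (no pad)
sub_mask = get_attn_subsequence_mask(dec_inputs)      # [2, 6, 6], upper-triangular
combined = torch.gt(pad_mask + sub_mask, 0)           # True wherever either mask fires
print(combined[0].int())
# tensor([[0, 1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.int32)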
Transformer
class Transformer(nn.Module):
def __init__(self,encoder,decoder,embed_dim,tgt_vocab_size):
super(Transformer,self).__init__()
self.encoder=encoder
self.decoder=decoder
self.fc=nn.Linear(embed_dim,tgt_vocab_size)
def forward(self,enc_inputs,dec_inputs):
'''
enc_inputs:[batch_size,src_len]
dec_inputs:[batch_size,tgt_len]
'''
enc_outputs,enc_attn_all=self.encoder(enc_inputs)
dec_outputs,dec_attn_all,enc_dec_attn_all=self.decoder(dec_inputs,enc_inputs,enc_outputs)
dec_outputs=self.fc(dec_outputs)
dec_outputs=dec_outputs.view(-1,dec_outputs.size(-1))
return dec_outputs,dec_attn_all,enc_dec_attn_all,enc_attn_all
The Transformer class mainly just wires the Encoder and the Decoder together. The final dec_logits it returns has shape [batch_size * tgt_len, tgt_vocab_size], which can be read as: the batch contains batch_size * tgt_len words in total, each word has tgt_vocab_size possible choices, and we take the most probable one.
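The flattening done by the last view call is easy to verify on a dummy tensor (illustrative only):
x = torch.randn(2, TGT_LEN, TGT_VOCAB_SIZE)  # stand-in for the output of self.fc
flat = x.view(-1, x.size(-1))
print(flat.shape)  # torch.Size([12, 9]), i.e. [batch_size * tgt_len, tgt_vocab_size]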
Model & Loss Function & Optimizer
multiheadattention = MultiHeadAttention(EMBED_DIM,NUM_HEAD,KEY_DIM,VALUE_DIM)
feedforward = FeedForward(EMBED_DIM,FF_DIM)
enc_pos_embed = get_positional_encoding(SRC_LEN,EMBED_DIM)
encoderlayer = EncoderLayer(multiheadattention,feedforward)
dec_pos_embed = get_positional_encoding(TGT_LEN,EMBED_DIM)
decoderlayer = DecoderLayer(multiheadattention,feedforward)
encoder = Encoder(SRC_VOCAB_SIZE,EMBED_DIM,enc_pos_embed,encoderlayer,NUM_LAYER)
decoder = Decoder(TGT_VOCAB_SIZE,EMBED_DIM,dec_pos_embed,decoderlayer,NUM_LAYER)
model = Transformer(encoder,decoder,EMBED_DIM,TGT_VOCAB_SIZE)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
In the loss function I set ignore_index=0, because the index of the "pad" token is 0; with this setting the loss on "pad" positions is simply not computed ("pad" carries no meaning anyway, so there is nothing to learn from it). For a slightly more detailed note on this parameter, see the bottom of my other article, where I briefly discuss it.
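A tiny demonstration (illustrative only) of what ignore_index=0 does:
logits = torch.randn(3, TGT_VOCAB_SIZE)  # three fake word predictions
targets = torch.tensor([4, 8, 0])        # the last target is the pad index 0
print(nn.CrossEntropyLoss()(logits, targets))                # averages over all 3 positions
print(nn.CrossEntropyLoss(ignore_index=0)(logits, targets))  # averages over the first 2 only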
Training
for epoch in range(30):
for enc_inputs, dec_inputs, dec_outputs in loader:
'''
enc_inputs: [batch_size, src_len]
dec_inputs: [batch_size, tgt_len]
dec_outputs: [batch_size, tgt_len]
'''
# enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)
# outputs: [batch_size * tgt_len, tgt_vocab_size]
outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
loss = criterion(outputs, dec_outputs.view(-1))
print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
optimizer.zero_grad()
loss.backward()
optimizer.step()
Testing
enc_inputs, dec_inputs, _ = next(iter(loader))
predict, _, _, _ = model(enc_inputs[0].view(1, -1), dec_inputs[0].view(1, -1)) # model(enc_inputs[0].view(1, -1), greedy_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print(enc_inputs[0], '->', [idx2word[n.item()] for n in predict.squeeze()])
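Note that the test above still feeds the ground-truth dec_inputs into the Decoder (teacher forcing), as the commented-out greedy_dec_input hints. A minimal greedy-decoding sketch (my own addition, assuming the model and vocabularies defined above) would instead generate the target one token at a time:
def greedy_decode(model, enc_input, start_symbol=tgt_vocab['S'], max_len=TGT_LEN):
    """Greedily generate a target sequence for a single source sentence (illustrative sketch)."""
    dec_input = torch.LongTensor([[start_symbol]])  # start with 'S'
    for _ in range(max_len):
        logits, _, _, _ = model(enc_input, dec_input)      # [cur_len, tgt_vocab_size]
        next_token = logits[-1].argmax(dim=-1).view(1, 1)  # most probable next word
        dec_input = torch.cat([dec_input, next_token], dim=-1)
        if next_token.item() == tgt_vocab['E']:
            break
    return dec_input.squeeze(0)[1:]  # drop the leading 'S'

pred = greedy_decode(model, enc_inputs[0].view(1, -1))
print([idx2word[i.item()] for i in pred])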
Complete Code
# import the required libraries
import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
import matplotlib.pyplot as plt
import math
import torch.optim as optim
# S: Symbol that shows starting of decoding input
# E: Symbol that shows ending of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is shorter than time steps
sentences = [
# enc_input dec_input dec_output
['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]
# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
SRC_VOCAB_SIZE = len(src_vocab)
tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
TGT_VOCAB_SIZE = len(tgt_vocab)
SRC_LEN = 5 # enc_input max sequence length
TGT_LEN = 6 # dec_input(=dec_output) max sequence length
def make_data(sentences):
enc_inputs, dec_inputs, dec_outputs = [], [], []
for i in range(len(sentences)):
enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]
#enc_inputs:[2,5]
#dec_inputs:[2,6]
#dec_output:[2,6]
enc_inputs.extend(enc_input)
dec_inputs.extend(dec_input)
dec_outputs.extend(dec_output)
return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)
class MyDataSet(Data.Dataset):
def __init__(self, enc_inputs, dec_inputs, dec_outputs):
super(MyDataSet, self).__init__()
self.enc_inputs = enc_inputs
self.dec_inputs = dec_inputs
self.dec_outputs = dec_outputs
def __len__(self):
return self.enc_inputs.shape[0]
def __getitem__(self, idx):
return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)
# Transformer Parameters
EMBED_DIM = 512 # Embedding Size
FF_DIM = 2048 # FeedForward dimension
KEY_DIM = VALUE_DIM = 64 # dimension of K(=Q), V
NUM_LAYER = 6 # number of Encoder and Decoder layers
NUM_HEAD = 8 # number of heads in Multi-Head Attention
def get_positional_encoding(max_seq_len, embed_dim):
    # build the positional encoding table
    # embed_dim: dimension of the word embedding
    # max_seq_len: maximum sequence length
positional_encoding = np.array([
[pos / np.power(10000, 2 * (i//2) / embed_dim) for i in range(embed_dim)]
if pos != 0 else np.zeros(embed_dim) for pos in range(max_seq_len)])
    positional_encoding[1:, 0::2] = np.sin(positional_encoding[1:, 0::2]) # dim 2i (even)
    positional_encoding[1:, 1::2] = np.cos(positional_encoding[1:, 1::2]) # dim 2i+1 (odd)
return torch.FloatTensor(positional_encoding)
def get_attn_pad_mask(seq_q, seq_k):
'''
seq_q: [batch_size, seq_len]
seq_k: [batch_size, seq_len]
seq_len could be src_len or it could be tgt_len
seq_len in seq_q and seq_len in seq_k maybe not equal
'''
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
# eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # [batch_size, 1, len_k], True means masked
return pad_attn_mask.expand(batch_size, len_q, len_k) # [batch_size, len_q, len_k]
def get_attn_subsequence_mask(seq):
"seq:[batch_size,dec_seq_len]"
attn_shape=[seq.shape[0],seq.shape[1],seq.shape[1]]
subsequence_mask=np.triu(np.ones(attn_shape),1)
subsequence_mask=torch.from_numpy(subsequence_mask).byte()
return subsequence_mask
def ScaledDotProductAttention(Q,K,V,attn_mask):
'''
Q:[batch_size,head,len_q,d_k]
K:[batch_size,head,len_k,d_k]
    V:[batch_size,head,len_v(=len_k),d_v]
attn_mask:[batch_size,head,len_q,len_k]
'''
scores=torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(Q.size(3))
scores.masked_fill_(attn_mask,-1e9)
attn=torch.nn.Softmax(-1)(scores)
context=torch.matmul(attn,V)
return context,attn
class MultiHeadAttention(nn.Module):
def __init__(self,embed_dim,num_head,query_dim,value_dim):
super(MultiHeadAttention,self).__init__()
self.head=num_head
self.query=query_dim
self.value=value_dim
self.W_Q=nn.Linear(embed_dim,num_head*query_dim,bias=True)
self.W_K=nn.Linear(embed_dim,num_head*query_dim,bias=True)
self.W_V=nn.Linear(embed_dim,num_head*value_dim,bias=True)
self.fc=nn.Linear(num_head*value_dim,embed_dim)
def forward(self,inputs_query,inputs_key,inputs_value,attn_mask):
'''
inputs_query:[batch_size,len_q,embed_dim]
inputs_key:[batch_size,len_k,embed_dim]
inputs_value:[batch_size,len_v,embed_dim]
attn_mask:[batch_size,len_q,len_k]
'''
residual,batch_size=inputs_query,inputs_query.size(0)
#Q:[batch_size,num_head,len_q,query_dim]
Q=self.W_Q(inputs_query).view(batch_size,-1,self.head,self.query).transpose(1,2)
#K:[batch_size,num_head,len_k,query_dim]
K=self.W_K(inputs_key).view(batch_size,-1,self.head,self.query).transpose(1,2)
#V:[batch_size,num_head,len_v,value_dim]
V=self.W_V(inputs_value).view(batch_size,-1,self.head,self.value).transpose(1,2)
#attn_mask:[batch_size,num_head,len_q,len_k]
attn_mask=attn_mask.unsqueeze(1).repeat(1,self.head,1,1)
#context:[batch_size,num_head,len_q,value_dim]
context,attn=ScaledDotProductAttention(Q,K,V,attn_mask)
context=context.transpose(1,2).contiguous().view(batch_size,-1,self.head*self.value)
output=self.fc(context)
return nn.LayerNorm(output.size(2))(output+residual),attn
class FeedForward(nn.Module):
def __init__(self,embed_dim,feedforward_dim):
super(FeedForward,self).__init__()
self.fc=nn.Sequential(
nn.Linear(embed_dim,feedforward_dim,bias=True),
nn.ReLU(),
nn.Linear(feedforward_dim,embed_dim,bias=True)
)
def forward(self,inputs):
"inputs:[batch_size,len_q,embed_size]"
residual=inputs
outputs=self.fc(inputs)
return nn.LayerNorm(inputs.size(2))(residual + outputs)#size:[batch_size,len_q,embed_size]
class EncoderLayer(nn.Module):
def __init__(self,multiheadattention,feedforward):
super(EncoderLayer,self).__init__()
self.enc_self_attention=multiheadattention
self.feedforward=feedforward
def forward(self,enc_inputs,enc_attn_mask):
'''
enc_inputs:[batch_size,seq_len,embed_size]
enc_attn_mask:[batch_size,len_q,len_k]
'''
#outputs:[batch_size,seq_len,embed_size]
#attn:[batch_size,len_q,len_k]
outputs,attn=self.enc_self_attention(enc_inputs,enc_inputs,enc_inputs,enc_attn_mask)
outputs=self.feedforward(outputs)
return outputs,attn
class Encoder(nn.Module):
def __init__(self,src_vocab_size,embed_dim,pos_embed,encoderlayer,num_layer):
super(Encoder,self).__init__()
self.src_embed=nn.Embedding(src_vocab_size,embed_dim)
self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
self.layer=nn.ModuleList([encoderlayer for _ in range(num_layer)])
def forward(self,enc_inputs):
"enc_inputs:[batch_size,seq_len]"
seq_len = enc_inputs.size(1)
pos = torch.LongTensor([i for i in range(seq_len)])
pos = pos.unsqueeze(0).expand_as(enc_inputs)
word_embed=self.src_embed(enc_inputs)
pos_embed=self.pos_embed(pos)
input_embed=word_embed+pos_embed
enc_attn_mask=get_attn_pad_mask(enc_inputs,enc_inputs)
enc_attn=[]
for layer in self.layer:
            outputs,attn=layer(input_embed,enc_attn_mask)
            input_embed = outputs
            enc_attn.append(attn)
return outputs,enc_attn
class DecoderLayer(nn.Module):
def __init__(self,multiheadattention,feedforward):
super(DecoderLayer,self).__init__()
self.dec_self_attn=multiheadattention
self.enc_dec_self_attn=multiheadattention
self.feedforward=feedforward
def forward(self,dec_inputs,enc_outputs,dec_attn_mask,enc_dec_attn_mask):
'''
dec_inputs:[batch_size,tgt_len,embed_dim]
enc_outputs:[batch_size,src_len,embed_dim]
dec_attn_mask:[batch_size,tgt_len,tgt_len]
        enc_dec_attn_mask:[batch_size,tgt_len,src_len]
'''
#dec_outputs:[batch_size,tgt_len,embed_dim]
dec_outputs,dec_attn=self.dec_self_attn(dec_inputs,dec_inputs,dec_inputs,dec_attn_mask)
dec_outputs,enc_dec_attn=self.enc_dec_self_attn(dec_outputs,enc_outputs,enc_outputs,enc_dec_attn_mask)
dec_outputs=self.feedforward(dec_outputs)
return dec_outputs,dec_attn,enc_dec_attn
class Decoder(nn.Module):
def __init__(self,tgt_vocab_size,embed_dim,pos_embed,decoderlayer,num_layer):
super(Decoder,self).__init__()
self.tgt_embed=nn.Embedding(tgt_vocab_size,embed_dim)
self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
self.layers=nn.ModuleList([decoderlayer for _ in range(num_layer)])
def forward(self,dec_inputs,enc_inputs,enc_outputs):
"dec_inputs:[batch_size,tgt_len]"
seq_len = dec_inputs.size(1)
pos = torch.LongTensor([i for i in range(seq_len)])
pos = pos.unsqueeze(0).expand_as(dec_inputs)
tgt_embed=self.tgt_embed(dec_inputs)
pos_embed=self.pos_embed(pos)
input_embed=tgt_embed+pos_embed
dec_self_attn_pad_mask=get_attn_pad_mask(dec_inputs,dec_inputs)
dec_subsequence_attn_mask=get_attn_subsequence_mask(dec_inputs)
dec_self_attn_mask=dec_self_attn_pad_mask+dec_subsequence_attn_mask
dec_self_attn_mask=torch.gt(dec_self_attn_mask,0)
enc_dec_self_attn_mask=get_attn_pad_mask(dec_inputs,enc_inputs)
dec_attn_all=[]
enc_dec_attn_all=[]
for layer in self.layers:
            outputs,dec_attn,enc_dec_attn=layer(input_embed,enc_outputs,dec_self_attn_mask,enc_dec_self_attn_mask)
            input_embed = outputs
            dec_attn_all.append(dec_attn)
enc_dec_attn_all.append(enc_dec_attn)
return outputs,dec_attn_all,enc_dec_attn_all
class Transformer(nn.Module):
def __init__(self,encoder,decoder,embed_dim,tgt_vocab_size):
super(Transformer,self).__init__()
self.encoder=encoder
self.decoder=decoder
self.fc=nn.Linear(embed_dim,tgt_vocab_size)
def forward(self,enc_inputs,dec_inputs):
'''
enc_inputs:[batch_size,src_len]
dec_inputs:[batch_size,tgt_len]
'''
enc_outputs,enc_attn_all=self.encoder(enc_inputs)
dec_outputs,dec_attn_all,enc_dec_attn_all=self.decoder(dec_inputs,enc_inputs,enc_outputs)
dec_outputs=self.fc(dec_outputs)
dec_outputs=dec_outputs.view(-1,dec_outputs.size(-1))
return dec_outputs,dec_attn_all,enc_dec_attn_all,enc_attn_all
multiheadattention = MultiHeadAttention(EMBED_DIM,NUM_HEAD,KEY_DIM,VALUE_DIM)
feedforward = FeedForward(EMBED_DIM,FF_DIM)
enc_pos_embed = get_positional_encoding(SRC_LEN,EMBED_DIM)
encoderlayer = EncoderLayer(multiheadattention,feedforward)
dec_pos_embed = get_positional_encoding(TGT_LEN,EMBED_DIM)
decoderlayer = DecoderLayer(multiheadattention,feedforward)
encoder = Encoder(SRC_VOCAB_SIZE,EMBED_DIM,enc_pos_embed,encoderlayer,NUM_LAYER)
decoder = Decoder(TGT_VOCAB_SIZE,EMBED_DIM,dec_pos_embed,decoderlayer,NUM_LAYER)
model = Transformer(encoder,decoder,EMBED_DIM,TGT_VOCAB_SIZE)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
for epoch in range(30):
for enc_inputs, dec_inputs, dec_outputs in loader:
'''
enc_inputs: [batch_size, src_len]
dec_inputs: [batch_size, tgt_len]
dec_outputs: [batch_size, tgt_len]
'''
# enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)
# outputs: [batch_size * tgt_len, tgt_vocab_size]
outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
loss = criterion(outputs, dec_outputs.view(-1))
print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
optimizer.zero_grad()
loss.backward()
optimizer.step()