Video walkthrough on Bilibili
This post walks through how to reproduce the Transformer in PyTorch and use it for a simple machine-translation task. Please spend 15 minutes first on my other article, Transformer 详解 (a detailed explanation of the Transformer), and then come back here; that way everything will click and you will get twice the result for half the effort.

Data Preprocessing

I am not using any large dataset here. Instead, I typed in two German→English sentence pairs by hand, and the index of every word is hard-coded as well. This is mainly to keep the code easy to read, so that readers can focus on the model implementation itself.

# import the required libraries
import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
import matplotlib.pyplot as plt
import math
import torch.optim as optim
# S: Symbol that marks the start of the decoder input
# E: Symbol that marks the end of the decoder output
# P: Symbol used to pad a sequence that is shorter than the maximum length
sentences = [
        # enc_input                dec_input            dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
SRC_VOCAB_SIZE = len(src_vocab)

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
TGT_VOCAB_SIZE = len(tgt_vocab)

SRC_LEN = 5 # enc_input max sequence length
TGT_LEN = 6 # dec_input(=dec_output) max sequence length

def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

        #enc_inputs:[2,5]
        #dec_inputs:[2,6]
        #dec_output:[2,6]
        enc_inputs.extend(enc_input)
        dec_inputs.extend(dec_input)
        dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
   def __init__(self, enc_inputs, dec_inputs, dec_outputs):
       super(MyDataSet, self).__init__()
       self.enc_inputs = enc_inputs
       self.dec_inputs = dec_inputs
       self.dec_outputs = dec_outputs

   def __len__(self):
       return self.enc_inputs.shape[0]

   def __getitem__(self, idx):
       return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)

Model Parameters

The variables below stand for, in order:

  1. Dimension of the word embedding & positional embedding; the two values are the same, so a single variable is enough
  2. Number of hidden units in the FeedForward layer
  3. Dimension of the Q, K and V vectors; Q and K must have the same dimension, V has no such restriction, but for convenience I set them all to 64
  4. Number of Encoder and Decoder layers
  5. Number of heads in Multi-Head Attention
# Transformer Parameters
EMBED_DIM = 512  # Embedding Size
FF_DIM = 2048 # FeedForward dimension
KEY_DIM = VALUE_DIM = 64  # dimension of K(=Q), V
NUM_LAYER = 6  # number of Encoder and Decoder layers
NUM_HEAD = 8  # number of heads in Multi-Head Attention

Everything so far has been fairly simple. From here on we get to the model itself, which is more involved, so I will break the model into the following pieces and explain them one by one:
  • Positional Encoding
  • Pad Mask (sentences that are too short are padded, so the pad tokens have to be masked)
  • Subsequence Mask (the Decoder input must not see words from future time steps, so they need to be masked)
  • ScaledDotProductAttention (computes the context vector)
  • Multi-Head Attention
  • FeedForward Layer
  • Encoder Layer
  • Encoder
  • Decoder Layer
  • Decoder
  • Transformer

About the comments in the code: whenever a value is definitely src_len or tgt_len, I spell it out. But some functions or classes can be called from both the Encoder and the Decoder, so the length cannot be pinned down to src_len or tgt_len; in those uncertain cases I write it as seq_len

Positional Encoding

def get_positional_encoding(max_seq_len, embed_dim):
    # build the positional encoding table
    # embed_dim: dimension of the word embedding
    # max_seq_len: maximum sequence length
    positional_encoding = np.array([
        [pos / np.power(10000, 2 * (i//2) / embed_dim) for i in range(embed_dim)]
        if pos != 0 else np.zeros(embed_dim) for pos in range(max_seq_len)])

    positional_encoding[1:, 0::2] = np.sin(positional_encoding[1:, 0::2])  # dim 2i (even)
    positional_encoding[1:, 1::2] = np.cos(positional_encoding[1:, 1::2])  # dim 2i+1 (odd)
    return torch.FloatTensor(positional_encoding)

This code is not complicated. The two arguments are the maximum sequence length and the dimension of the positional encoding. The returned tensor has shape [max_seq_len, embed_dim], exactly the same as the word embedding.
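As a quick sanity check (a minimal sketch using the SRC_LEN and EMBED_DIM constants defined above), you can print the table and confirm its shape, and that row 0 is all zeros in this implementation:

# quick sanity check of the positional encoding table
pe = get_positional_encoding(SRC_LEN, EMBED_DIM)
print(pe.shape)   # torch.Size([5, 512]) -- [max_seq_len, embed_dim]
print(pe[0, :4])  # row 0 is all zeros in this implementation
print(pe[1, :4])  # row 1 holds the interleaved sin/cos values, roughly [0.84, 0.54, ...]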

Pad Mask

def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True marks the positions to be masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

Since mask operations are needed in both the Encoder and the Decoder, we cannot pin down the value of seq_len in this function's arguments. When it is called from the Encoder, seq_len equals src_len; when it is called from the Decoder, seq_len may equal src_len or tgt_len (the Decoder masks twice).
The key line of this function is seq_k.data.eq(0), which returns a tensor of the same size as seq_k containing only True and False: wherever seq_k is 0, the corresponding position is True, otherwise it is False. For example, for the input seq_data = [1, 2, 3, 4, 0], seq_data.data.eq(0) returns [False, False, False, False, True].
The rest of the code just expands the dimensions. I strongly recommend printing the result to see exactly what is returned.
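To make that concrete, here is a tiny standalone check (toy_q and toy_k are made-up inputs, not part of the original data):

# toy check: one sequence of length 5, where 0 is the PAD index
toy_q = torch.LongTensor([[1, 2, 3, 4, 0]])
toy_k = torch.LongTensor([[1, 2, 3, 4, 0]])
mask = get_attn_pad_mask(toy_q, toy_k)
print(mask.shape)  # torch.Size([1, 5, 5])
print(mask[0])     # every row ends with True: the PAD key position is masked for every query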

Subsequence Mask

def get_attn_subsequence_mask(seq):
    "seq:[batch_size,dec_seq_len]"
    attn_shape=[seq.shape[0],seq.shape[1],seq.shape[1]]
    subsequence_mask=np.triu(np.ones(attn_shape),1)
    subsequence_mask=torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask

The Subsequence Mask is only used by the Decoder; its job is to hide the words at future time steps. np.ones() first builds an all-ones square matrix, and np.triu() then keeps only its upper triangle (with k=1, the main diagonal is zeroed as well). A small example of np.triu() follows below.
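A minimal illustration (toy_dec is a made-up decoder input of length 4):

# np.triu keeps the upper triangle; k=1 also zeroes the main diagonal
print(np.triu(np.ones((4, 4)), 1))
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]

toy_dec = torch.LongTensor([[6, 1, 2, 3]])    # made-up decoder input
print(get_attn_subsequence_mask(toy_dec)[0])  # the same upper-triangular pattern, as a byte tensor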

ScaledDotProductAttention

def ScaledDotProductAttention(Q,K,V,attn_mask):

    '''
    Q:[batch_size,head,len_q,d_k]
    K:[batch_size,head,len_k,d_k]
    V:[batch_size,head,len_v(=len_k),d_v]
    attn_mask:[batch_size,head,len_q,len_k]
    '''
    scores=torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(Q.size(3))
    scores.masked_fill_(attn_mask,-1e9)

    attn=torch.nn.Softmax(-1)(scores)
    context=torch.matmul(attn,V)

    return context,attn

What happens here: Q and K are used to compute scores, and scores is then multiplied with V to obtain the context vector of every word.
The first step, multiplying Q with the transpose of K, needs no explanation. The resulting scores cannot go straight into softmax: they must first be masked with attn_mask (masked_fill_ writes -1e9 into the positions that should be hidden). attn_mask is a tensor containing only True and False, and all four of its dimensions are guaranteed to match those of scores (otherwise the element-wise masking would not be possible).
Once the masking is done, scores goes through softmax and is then multiplied with V to obtain context.
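Here is a tiny worked example (all tensors are made up, with batch_size = head = 1, len_q = len_k = 2, d_k = d_v = 2) showing that a masked position ends up with essentially zero attention weight:

Q = torch.ones(1, 1, 2, 2)
K = torch.ones(1, 1, 2, 2)
V = torch.tensor([[[[1., 0.], [0., 1.]]]])
mask = torch.tensor([[[[False, True],    # the second key position is masked
                       [False, True]]]])
context, attn = ScaledDotProductAttention(Q, K, V, mask)
print(attn)     # roughly [[1, 0], [1, 0]] -- the masked position gets ~0 weight
print(context)  # each query therefore just copies V[..., 0, :] = [1, 0]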

MultiHeadAttention

class MultiHeadAttention(nn.Module):
    def __init__(self,embed_dim,num_head,query_dim,value_dim):
        super(MultiHeadAttention,self).__init__()
        self.head=num_head
        self.query=query_dim
        self.value=value_dim
        self.W_Q=nn.Linear(embed_dim,num_head*query_dim,bias=True)
        self.W_K=nn.Linear(embed_dim,num_head*query_dim,bias=True)
        self.W_V=nn.Linear(embed_dim,num_head*value_dim,bias=True)
        self.fc=nn.Linear(num_head*value_dim,embed_dim)

    def forward(self,inputs_query,inputs_key,inputs_value,attn_mask):
        '''
        inputs_query:[batch_size,len_q,embed_dim]
        inputs_key:[batch_size,len_k,embed_dim]
        inputs_value:[batch_size,len_v,embed_dim]
        attn_mask:[batch_size,len_q,len_k]
        '''
        residual,batch_size=inputs_query,inputs_query.size(0)

        #Q:[batch_size,num_head,len_q,query_dim]
        Q=self.W_Q(inputs_query).view(batch_size,-1,self.head,self.query).transpose(1,2)
        #K:[batch_size,num_head,len_k,query_dim]
        K=self.W_K(inputs_key).view(batch_size,-1,self.head,self.query).transpose(1,2)
        #V:[batch_size,num_head,len_v,value_dim]
        V=self.W_V(inputs_value).view(batch_size,-1,self.head,self.value).transpose(1,2)

        #attn_mask:[batch_size,num_head,len_q,len_k]
        attn_mask=attn_mask.unsqueeze(1).repeat(1,self.head,1,1)
        #context:[batch_size,num_head,len_q,value_dim]
        context,attn=ScaledDotProductAttention(Q,K,V,attn_mask)
        context=context.transpose(1,2).contiguous().view(batch_size,-1,self.head*self.value)
        output=self.fc(context)
        return nn.LayerNorm(output.size(2))(output+residual),attn

The full code calls MultiHeadAttention() in exactly three places. The Encoder Layer calls it once, passing enc_inputs for all of inputs_query, inputs_key and inputs_value. The Decoder Layer calls it twice: the first call passes dec_inputs everywhere, and the second passes dec_outputs, enc_outputs and enc_outputs respectively.
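A quick shape check with random tensors (a throwaway example, using the constants from the parameter section; nothing is masked here):

mha = MultiHeadAttention(EMBED_DIM, NUM_HEAD, KEY_DIM, VALUE_DIM)
x = torch.randn(2, 5, EMBED_DIM)                   # [batch_size, seq_len, embed_dim]
no_mask = torch.zeros(2, 5, 5, dtype=torch.bool)   # all False: nothing is masked
out, attn = mha(x, x, x, no_mask)
print(out.shape)   # torch.Size([2, 5, 512]) -- same shape as the input
print(attn.shape)  # torch.Size([2, 8, 5, 5]) -- [batch_size, num_head, len_q, len_k]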

FeedForward Layer

class FeedForward(nn.Module):
    def __init__(self,embed_dim,feedforward_dim):
        super(FeedForward,self).__init__()
        self.fc=nn.Sequential(
            nn.Linear(embed_dim,feedforward_dim,bias=True),
            nn.ReLU(),
            nn.Linear(feedforward_dim,embed_dim,bias=True)
        )

    def forward(self,inputs):
        "inputs:[batch_size,len_q,embed_size]"
        residual=inputs
        outputs=self.fc(inputs)
        return nn.LayerNorm(inputs.size(2))(residual + outputs)#size:[batch_size,len_q,embed_size]

This code is very simple: two linear transformations, a residual connection, and then a Layer Norm.
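A quick check with a random tensor confirms that the layer keeps the input shape unchanged (a throwaway example):

ffn = FeedForward(EMBED_DIM, FF_DIM)
x = torch.randn(2, 5, EMBED_DIM)
print(ffn(x).shape)  # torch.Size([2, 5, 512]) -- [batch_size, len_q, embed_size]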

Encoder Layer

class EncoderLayer(nn.Module):
    def __init__(self,multiheadattention,feedforward):
        super(EncoderLayer,self).__init__()
        self.enc_self_attention=multiheadattention
        self.feedforward=feedforward
    def forward(self,enc_inputs,enc_attn_mask):
        '''
        enc_inputs:[batch_size,seq_len,embed_size]
        enc_attn_mask:[batch_size,len_q,len_k]
        '''
        #outputs:[batch_size,seq_len,embed_size]
        #attn:[batch_size,num_head,len_q,len_k]
        outputs,attn=self.enc_self_attention(enc_inputs,enc_inputs,enc_inputs,enc_attn_mask)
        outputs=self.feedforward(outputs)
        return outputs,attn

Putting the components above together gives a complete Encoder Layer.

Encoder

class Encoder(nn.Module):
    def __init__(self,src_vocab_size,embed_dim,pos_embed,encoderlayer,num_layer):
        super(Encoder,self).__init__()
        self.src_embed=nn.Embedding(src_vocab_size,embed_dim)
        self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
        self.layer=nn.ModuleList([encoderlayer for _ in range(num_layer)])

    def forward(self,enc_inputs):
        "enc_inputs:[batch_size,seq_len]"
        seq_len = enc_inputs.size(1)
        pos = torch.LongTensor([i for i in range(seq_len)])
        pos = pos.unsqueeze(0).expand_as(enc_inputs)

        word_embed=self.src_embed(enc_inputs)
        pos_embed=self.pos_embed(pos)
        input_embed=word_embed+pos_embed
        enc_attn_mask=get_attn_pad_mask(enc_inputs,enc_inputs)
        enc_attn=[]
        for layer in self.layer:
            outputs,attn=layer(input_embed,enc_attn_mask)
            input_embed = outputs
            enc_attn.append(attn) 
        return outputs,enc_attn

nn.ModuleList() takes a list as its argument; here the list holds num_layer Encoder Layers (note that it is the same encoderlayer object repeated num_layer times, so in this implementation all the layers share one set of weights).
Because the Encoder Layer is written so that its input and output have the same dimensions, a plain for loop can chain them: the output of one Encoder Layer becomes the input of the next. If you would rather give every layer its own parameters, see the sketch below.
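A possible tweak (my own sketch, not the author's code; clone_layers is a hypothetical helper) that gives each layer independent weights, as in the original Transformer paper, is to deep-copy a prototype layer:

import copy

# sketch: build a ModuleList of independent copies so that each layer has its own weights
def clone_layers(prototype_layer, num_layer):
    return nn.ModuleList([copy.deepcopy(prototype_layer) for _ in range(num_layer)])

# e.g. inside Encoder.__init__:
#     self.layer = clone_layers(encoderlayer, num_layer)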

Decoder Layer

class DecoderLayer(nn.Module):
    def __init__(self,multiheadattention,feedforward):
        super(DecoderLayer,self).__init__()
        self.dec_self_attn=multiheadattention
        self.enc_dec_self_attn=multiheadattention
        self.feedforward=feedforward

    def forward(self,dec_inputs,enc_outputs,dec_attn_mask,enc_dec_attn_mask):
        '''
        dec_inputs:[batch_size,tgt_len,embed_dim]
        enc_outputs:[batch_size,src_len,embed_dim]
        dec_attn_mask:[batch_size,tgt_len,tgt_len]
        enc_dec_attn_mask:[batch_size,tgt_len,src_len]
        '''
        #dec_outputs:[batch_size,tgt_len,embed_dim]
        dec_outputs,dec_attn=self.dec_self_attn(dec_inputs,dec_inputs,dec_inputs,dec_attn_mask)
        dec_outputs,enc_dec_attn=self.enc_dec_self_attn(dec_outputs,enc_outputs,enc_outputs,enc_dec_attn_mask)
        dec_outputs=self.feedforward(dec_outputs)
        return dec_outputs,dec_attn,enc_dec_attn

The Decoder Layer calls MultiHeadAttention twice. The first call computes the self-attention of the Decoder input and produces dec_outputs. Then dec_outputs is used to generate Q while enc_outputs generates K and V, and MultiHeadAttention is called again; this gives the context vector between the Encoder and the Decoder Layer. Finally dec_outputs is passed through the FeedForward layer and returned.

Decoder

class Decoder(nn.Module):
    def __init__(self,tgt_vocab_size,embed_dim,pos_embed,decoderlayer,num_layer):
        super(Decoder,self).__init__()
        self.tgt_embed=nn.Embedding(tgt_vocab_size,embed_dim)
        self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
        self.layers=nn.ModuleList([decoderlayer for _ in range(num_layer)])

    def forward(self,dec_inputs,enc_inputs,enc_outputs):
        "dec_inputs:[batch_size,tgt_len]"
        seq_len = dec_inputs.size(1)
        pos = torch.LongTensor([i for i in range(seq_len)])
        pos = pos.unsqueeze(0).expand_as(dec_inputs)

        tgt_embed=self.tgt_embed(dec_inputs)
        pos_embed=self.pos_embed(pos)
        input_embed=tgt_embed+pos_embed

        dec_self_attn_pad_mask=get_attn_pad_mask(dec_inputs,dec_inputs)
        dec_subsequence_attn_mask=get_attn_subsequence_mask(dec_inputs)
        dec_self_attn_mask=dec_self_attn_pad_mask+dec_subsequence_attn_mask
        dec_self_attn_mask=torch.gt(dec_self_attn_mask,0)
        enc_dec_self_attn_mask=get_attn_pad_mask(dec_inputs,enc_inputs)

        dec_attn_all=[]
        enc_dec_attn_all=[]
        for layer in self.layers:
            outputs,dec_attn,enc_dec_attn=layer(input_embed,enc_outputs,dec_self_attn_mask,enc_dec_self_attn_mask)
            input_embed = outputs
            dec_attn_all.append(dec_attn)
            enc_dec_attn_all.append(enc_dec_attn)

        return outputs,dec_attn_all,enc_dec_attn_all

The Decoder has to mask not only the 'pad' tokens but also the information from future time steps, hence the three lines below. torch.gt(a, value) compares every element of a with value: positions greater than value become True (1), the rest become False (0).

dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)    # [batch_size, tgt_len, tgt_len]
dec_subsequence_attn_mask = get_attn_subsequence_mask(dec_inputs)     # [batch_size, tgt_len, tgt_len]
dec_self_attn_mask = torch.gt(dec_self_attn_pad_mask + dec_subsequence_attn_mask, 0)  # [batch_size, tgt_len, tgt_len]
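A tiny torch.gt illustration (toy tensor): adding the two masks produces values of 0, 1 or 2, and torch.gt(..., 0) turns anything non-zero back into True:

a = torch.tensor([[0, 1, 2],
                  [1, 0, 0]])
print(torch.gt(a, 0))
# tensor([[False,  True,  True],
#         [ True, False, False]])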

Transformer

class Transformer(nn.Module):
    def __init__(self,encoder,decoder,embed_dim,tgt_vocab_size):
        super(Transformer,self).__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.fc=nn.Linear(embed_dim,tgt_vocab_size)

    def forward(self,enc_inputs,dec_inputs):
        '''
        enc_inputs:[batch_size,src_len]
        dec_inputs:[batch_size,tgt_len]
        '''
        enc_outputs,enc_attn_all=self.encoder(enc_inputs)
        dec_outputs,dec_attn_all,enc_dec_attn_all=self.decoder(dec_inputs,enc_inputs,enc_outputs)
        dec_outputs=self.fc(dec_outputs)
        dec_outputs=dec_outputs.view(-1,dec_outputs.size(-1))
        return dec_outputs,dec_attn_all,enc_dec_attn_all,enc_attn_all

The Transformer mainly just calls the Encoder and the Decoder. The logits it finally returns have shape [batch_size * tgt_len, tgt_vocab_size], which you can read as: there are batch_size * tgt_len words in total, each word has tgt_vocab_size possible outcomes, and we take the one with the highest probability.

Model & Loss Function & Optimizer

multiheadattention = MultiHeadAttention(EMBED_DIM,NUM_HEAD,KEY_DIM,VALUE_DIM)
feedforward = FeedForward(EMBED_DIM,FF_DIM)

enc_pos_embed = get_positional_encoding(SRC_LEN,EMBED_DIM)
encoderlayer = EncoderLayer(multiheadattention,feedforward)
dec_pos_embed = get_positional_encoding(TGT_LEN,EMBED_DIM)
decoderlayer = DecoderLayer(multiheadattention,feedforward)

encoder = Encoder(SRC_VOCAB_SIZE,EMBED_DIM,enc_pos_embed,encoderlayer,NUM_LAYER)
decoder = Decoder(TGT_VOCAB_SIZE,EMBED_DIM,dec_pos_embed,decoderlayer,NUM_LAYER)

model = Transformer(encoder,decoder,EMBED_DIM,TGT_VOCAB_SIZE)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

In the loss function I set the parameter ignore_index=0, because the index of the word 'pad' is 0. With this setting, no loss is computed for 'pad' positions ('pad' carries no meaning anyway, so there is nothing to compute). For a slightly more detailed note on this parameter, see the very end of my other article.
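A small illustration of what ignore_index=0 does (made-up logits and targets): the position whose target is 0, i.e. 'P', simply does not contribute to the loss:

logits = torch.randn(3, TGT_VOCAB_SIZE)   # 3 positions, each with tgt_vocab_size scores
targets = torch.LongTensor([1, 4, 0])     # the last position is padding

loss_ignore = nn.CrossEntropyLoss(ignore_index=0)(logits, targets)
loss_manual = nn.CrossEntropyLoss()(logits[:2], targets[:2])
print(torch.allclose(loss_ignore, loss_manual))  # True -- the padded position is skipped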

Training

for epoch in range(30):
    for enc_inputs, dec_inputs, dec_outputs in loader:
      '''
      enc_inputs: [batch_size, src_len]
      dec_inputs: [batch_size, tgt_len]
      dec_outputs: [batch_size, tgt_len]
      '''
      # enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)
      # outputs: [batch_size * tgt_len, tgt_vocab_size]
      outputs, dec_self_attns, dec_enc_attns, enc_self_attns = model(enc_inputs, dec_inputs)
      loss = criterion(outputs, dec_outputs.view(-1))

      print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

Testing

enc_inputs, dec_inputs, _ = next(iter(loader))
predict, _, _, _ = model(enc_inputs[0].view(1, -1), dec_inputs[0].view(1, -1)) # model(enc_inputs[0].view(1, -1), greedy_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print(enc_inputs[0], '->', [idx2word[n.item()] for n in predict.squeeze()])
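Note that the test above feeds the ground-truth dec_inputs into the model (teacher forcing), which is what the commented-out greedy_dec_input hints at. For real inference you would build the decoder input one token at a time; below is a rough greedy-decoding sketch, my own addition rather than part of the original code:

def greedy_decode(model, enc_input, start_symbol=tgt_vocab['S'], max_len=TGT_LEN):
    # sketch: feed the tokens generated so far back into the decoder, one step at a time
    dec_input = torch.LongTensor([[start_symbol]])
    for _ in range(max_len):
        logits, _, _, _ = model(enc_input, dec_input)         # [cur_len, tgt_vocab_size]
        next_token = logits[-1].argmax(dim=-1, keepdim=True)  # most likely next word
        dec_input = torch.cat([dec_input, next_token.view(1, 1)], dim=-1)
        if next_token.item() == tgt_vocab['E']:
            break
    return dec_input[:, 1:]  # drop the leading 'S'

pred = greedy_decode(model, enc_inputs[0].view(1, -1))
print([idx2word[n.item()] for n in pred.squeeze(0)])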

Complete Code

# import the required libraries
import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
import matplotlib.pyplot as plt
import math
import torch.optim as optim

# S: Symbol that marks the start of the decoder input
# E: Symbol that marks the end of the decoder output
# P: Symbol used to pad a sequence that is shorter than the maximum length
sentences = [
        # enc_input                dec_input            dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# Padding Should be Zero
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
SRC_VOCAB_SIZE = len(src_vocab)

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
idx2word = {i: w for i, w in enumerate(tgt_vocab)}
TGT_VOCAB_SIZE = len(tgt_vocab)

SRC_LEN = 5 # enc_input max sequence length
TGT_LEN = 6 # dec_input(=dec_output) max sequence length

def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] # [[1, 2, 3, 4, 0], [1, 2, 3, 5, 0]]
        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] # [[6, 1, 2, 3, 4, 8], [6, 1, 2, 3, 5, 8]]
        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] # [[1, 2, 3, 4, 8, 7], [1, 2, 3, 5, 8, 7]]

        #enc_inputs:[2,5]
        #dec_inputs:[2,6]
        #dec_output:[2,6]
        enc_inputs.extend(enc_input)
        dec_inputs.extend(dec_input)
        dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
   def __init__(self, enc_inputs, dec_inputs, dec_outputs):
       super(MyDataSet, self).__init__()
       self.enc_inputs = enc_inputs
       self.dec_inputs = dec_inputs
       self.dec_outputs = dec_outputs

   def __len__(self):
       return self.enc_inputs.shape[0]

   def __getitem__(self, idx):
       return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)

# Transformer Parameters
EMBED_DIM = 512  # Embedding Size
FF_DIM = 2048 # FeedForward dimension
KEY_DIM = VALUE_DIM = 64  # dimension of K(=Q), V
NUM_LAYER = 6  # number of Encoder and Decoder layers
NUM_HEAD = 8  # number of heads in Multi-Head Attention

def get_positional_encoding(max_seq_len, embed_dim):
    # build the positional encoding table
    # embed_dim: dimension of the word embedding
    # max_seq_len: maximum sequence length
    positional_encoding = np.array([
        [pos / np.power(10000, 2 * (i//2) / embed_dim) for i in range(embed_dim)]
        if pos != 0 else np.zeros(embed_dim) for pos in range(max_seq_len)])

    positional_encoding[1:, 0::2] = np.sin(positional_encoding[1:, 0::2])  # dim 2i (even)
    positional_encoding[1:, 1::2] = np.cos(positional_encoding[1:, 1::2])  # dim 2i+1 (odd)
    return torch.FloatTensor(positional_encoding)

def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, seq_len]
    seq_k: [batch_size, seq_len]
    seq_len could be src_len or it could be tgt_len
    seq_len in seq_q and seq_len in seq_k maybe not equal
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True marks the positions to be masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

def get_attn_subsequence_mask(seq):
    "seq:[batch_size,dec_seq_len]"
    attn_shape=[seq.shape[0],seq.shape[1],seq.shape[1]]
    subsequence_mask=np.triu(np.ones(attn_shape),1)
    subsequence_mask=torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask

def ScaledDotProductAttention(Q,K,V,attn_mask):

    '''
    Q:[batch_size,head,len_q,d_k]
    K:[batch_size,head,len_k,d_k]
    V:[batch_size,head,len_v(=len_k),d_v]
    attn_mask:[batch_size,head,len_q,len_k]
    '''
    scores=torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(Q.size(3))
    scores.masked_fill_(attn_mask,-1e9)

    attn=torch.nn.Softmax(-1)(scores)
    context=torch.matmul(attn,V)

    return context,attn

class MultiHeadAttention(nn.Module):
    def __init__(self,embed_dim,num_head,query_dim,value_dim):
        super(MultiHeadAttention,self).__init__()
        self.head=num_head
        self.query=query_dim
        self.value=value_dim
        self.W_Q=nn.Linear(embed_dim,num_head*query_dim,bias=True)
        self.W_K=nn.Linear(embed_dim,num_head*query_dim,bias=True)
        self.W_V=nn.Linear(embed_dim,num_head*value_dim,bias=True)
        self.fc=nn.Linear(num_head*value_dim,embed_dim)

    def forward(self,inputs_query,inputs_key,inputs_value,attn_mask):
        '''
        inputs_query:[batch_size,len_q,embed_dim]
        inputs_key:[batch_size,len_k,embed_dim]
        inputs_value:[batch_size,len_v,embed_dim]
        attn_mask:[batch_size,len_q,len_k]
        '''
        residual,batch_size=inputs_query,inputs_query.size(0)

        #Q:[batch_size,num_head,len_q,query_dim]
        Q=self.W_Q(inputs_query).view(batch_size,-1,self.head,self.query).transpose(1,2)
        #K:[batch_size,num_head,len_k,query_dim]
        K=self.W_K(inputs_key).view(batch_size,-1,self.head,self.query).transpose(1,2)
        #V:[batch_size,num_head,len_v,value_dim]
        V=self.W_V(inputs_value).view(batch_size,-1,self.head,self.value).transpose(1,2)

        #attn_mask:[batch_size,num_head,len_q,len_k]
        attn_mask=attn_mask.unsqueeze(1).repeat(1,self.head,1,1)
        #context:[batch_size,num_head,len_q,value_dim]
        context,attn=ScaledDotProductAttention(Q,K,V,attn_mask)
        context=context.transpose(1,2).contiguous().view(batch_size,-1,self.head*self.value)
        output=self.fc(context)
        return nn.LayerNorm(output.size(2))(output+residual),attn


class FeedForward(nn.Module):
    def __init__(self,embed_dim,feedforward_dim):
        super(FeedForward,self).__init__()
        self.fc=nn.Sequential(
            nn.Linear(embed_dim,feedforward_dim,bias=True),
            nn.ReLU(),
            nn.Linear(feedforward_dim,embed_dim,bias=True)
        )

    def forward(self,inputs):
        "inputs:[batch_size,len_q,embed_size]"
        residual=inputs
        outputs=self.fc(inputs)
        return nn.LayerNorm(inputs.size(2))(residual + outputs)#size:[batch_size,len_q,embed_size]

class EncoderLayer(nn.Module):
    def __init__(self,multiheadattention,feedforward):
        super(EncoderLayer,self).__init__()
        self.enc_self_attention=multiheadattention
        self.feedforward=feedforward
    def forward(self,enc_inputs,enc_attn_mask):
        '''
        enc_inputs:[batch_size,seq_len,embed_size]
        enc_attn_mask:[batch_size,len_q,len_k]
        '''
        #outputs:[batch_size,seq_len,embed_size]
        #attn:[batch_size,num_head,len_q,len_k]
        outputs,attn=self.enc_self_attention(enc_inputs,enc_inputs,enc_inputs,enc_attn_mask)
        outputs=self.feedforward(outputs)
        return outputs,attn

class Encoder(nn.Module):
    def __init__(self,src_vocab_size,embed_dim,pos_embed,encoderlayer,num_layer):
        super(Encoder,self).__init__()
        self.src_embed=nn.Embedding(src_vocab_size,embed_dim)
        self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
        self.layer=nn.ModuleList([encoderlayer for _ in range(num_layer)])

    def forward(self,enc_inputs):
        "enc_inputs:[batch_size,seq_len]"
        seq_len = enc_inputs.size(1)
        pos = torch.LongTensor([i for i in range(seq_len)])
        pos = pos.unsqueeze(0).expand_as(enc_inputs)

        word_embed=self.src_embed(enc_inputs)
        pos_embed=self.pos_embed(pos)
        input_embed=word_embed+pos_embed
        enc_attn_mask=get_attn_pad_mask(enc_inputs,enc_inputs)
        enc_attn=[]
        for layer in self.layer:
            outputs,attn=layer(input_embed,enc_attn_mask)
            input_embed = outputs
            enc_attn.append(attn)
        return outputs,enc_attn

class DecoderLayer(nn.Module):
    def __init__(self,multiheadattention,feedforward):
        super(DecoderLayer,self).__init__()
        self.dec_self_attn=multiheadattention
        self.enc_dec_self_attn=multiheadattention
        self.feedforward=feedforward

    def forward(self,dec_inputs,enc_outputs,dec_attn_mask,enc_dec_attn_mask):
        '''
        dec_inputs:[batch_size,tgt_len,embed_dim]
        enc_outputs:[batch_size,src_len,embed_dim]
        dec_attn_mask:[batch_size,tgt_len,tgt_len]
        enc_dec_attn_mask:[batch_size,tgt_len,src_len]
        '''
        #dec_outputs:[batch_size,tgt_len,embed_dim]
        dec_outputs,dec_attn=self.dec_self_attn(dec_inputs,dec_inputs,dec_inputs,dec_attn_mask)
        dec_outputs,enc_dec_attn=self.enc_dec_self_attn(dec_outputs,enc_outputs,enc_outputs,enc_dec_attn_mask)
        dec_outputs=self.feedforward(dec_outputs)
        return dec_outputs,dec_attn,enc_dec_attn

class Decoder(nn.Module):
    def __init__(self,tgt_vocab_size,embed_dim,pos_embed,decoderlayer,num_layer):
        super(Decoder,self).__init__()
        self.tgt_embed=nn.Embedding(tgt_vocab_size,embed_dim)
        self.pos_embed=nn.Embedding.from_pretrained(pos_embed,freeze=True)
        self.layers=nn.ModuleList([decoderlayer for _ in range(num_layer)])

    def forward(self,dec_inputs,enc_inputs,enc_outputs):
        "dec_inputs:[batch_size,tgt_len]"
        seq_len = dec_inputs.size(1)
        pos = torch.LongTensor([i for i in range(seq_len)])
        pos = pos.unsqueeze(0).expand_as(dec_inputs)

        tgt_embed=self.tgt_embed(dec_inputs)
        pos_embed=self.pos_embed(pos)
        input_embed=tgt_embed+pos_embed

        dec_self_attn_pad_mask=get_attn_pad_mask(dec_inputs,dec_inputs)
        dec_subsequence_attn_mask=get_attn_subsequence_mask(dec_inputs)
        dec_self_attn_mask=dec_self_attn_pad_mask+dec_subsequence_attn_mask
        dec_self_attn_mask=torch.gt(dec_self_attn_mask,0)
        enc_dec_self_attn_mask=get_attn_pad_mask(dec_inputs,enc_inputs)

        dec_attn_all=[]
        enc_dec_attn_all=[]
        for layer in self.layers:
            outputs,dec_attn,enc_dec_attn=layer(input_embed,enc_outputs,dec_self_attn_mask,enc_dec_self_attn_mask)
            input_embed = outputs
            dec_attn_all.append(dec_attn)
            enc_dec_attn_all.append(enc_dec_attn)

        return outputs,dec_attn_all,enc_dec_attn_all

class Transformer(nn.Module):
    def __init__(self,encoder,decoder,embed_dim,tgt_vocab_size):
        super(Transformer,self).__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.fc=nn.Linear(embed_dim,tgt_vocab_size)

    def forward(self,enc_inputs,dec_inputs):
        '''
        enc_inputs:[batch_size,src_len]
        dec_inputs:[batch_size,tgt_len]
        '''
        enc_outputs,enc_attn_all=self.encoder(enc_inputs)
        dec_outputs,dec_attn_all,enc_dec_attn_all=self.decoder(dec_inputs,enc_inputs,enc_outputs)
        dec_outputs=self.fc(dec_outputs)
        dec_outputs=dec_outputs.view(-1,dec_outputs.size(-1))
        return dec_outputs,dec_attn_all,enc_dec_attn_all,enc_attn_all

multiheadattention = MultiHeadAttention(EMBED_DIM,NUM_HEAD,KEY_DIM,VALUE_DIM)
feedforward = FeedForward(EMBED_DIM,FF_DIM)

enc_pos_embed = get_positional_encoding(SRC_LEN,EMBED_DIM)
encoderlayer = EncoderLayer(multiheadattention,feedforward)
dec_pos_embed = get_positional_encoding(TGT_LEN,EMBED_DIM)
decoderlayer = DecoderLayer(multiheadattention,feedforward)

encoder = Encoder(SRC_VOCAB_SIZE,EMBED_DIM,enc_pos_embed,encoderlayer,NUM_LAYER)
decoder = Decoder(TGT_VOCAB_SIZE,EMBED_DIM,dec_pos_embed,decoderlayer,NUM_LAYER)

model = Transformer(encoder,decoder,EMBED_DIM,TGT_VOCAB_SIZE)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

for epoch in range(30):
    for enc_inputs, dec_inputs, dec_outputs in loader:
      '''
      enc_inputs: [batch_size, src_len]
      dec_inputs: [batch_size, tgt_len]
      dec_outputs: [batch_size, tgt_len]
      '''
      # enc_inputs, dec_inputs, dec_outputs = enc_inputs.to(device), dec_inputs.to(device), dec_outputs.to(device)
      # outputs: [batch_size * tgt_len, tgt_vocab_size]
      outputs, dec_self_attns, dec_enc_attns, enc_self_attns = model(enc_inputs, dec_inputs)
      loss = criterion(outputs, dec_outputs.view(-1))

      print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()