This article explains how to reproduce Seq2Seq (with Attention) in PyTorch and use it for a simple machine translation task. I will first introduce the principles of seq2seq and the attention mechanism, then walk through the implementation of each module of the seq2seq (with attention) model, and finally give the complete code.

1. Principles of seq2seq and the attention mechanism

1.1 Encoder and decoder

The encoder and the decoder are two recurrent neural networks, corresponding to the input sequence and the output sequence respectively. We usually append a special token '<eos>' (end of sequence) to both the input sequence and the output sequence to mark where a sequence ends. At test time, the output sequence is terminated as soon as '<eos>' is produced.

1.1.1 Encoder

The encoder's job is to turn an input sequence of arbitrary length into a fixed-length context vector $c$. This context vector summarizes the information of the input sequence. A recurrent neural network is the most common choice of encoder.

Let us first review recurrent neural networks. Suppose the RNN unit is $f$ and the input at time step $t$ is $x_t$, $t = 1, \ldots, T$, where $x_t$ is the embedding of a single token, e.g. the product of its one-hot vector and the embedding parameter matrix $W$. The hidden state is
$$h_t = f(x_t, h_{t-1})$$
The encoder's context vector is
$$c = q(h_1, \ldots, h_T)$$
A simple choice of context vector is the hidden state of the final time step, $h_T$, i.e. $c = h_T$. We call this recurrent neural network the encoder.
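
To make this concrete, here is a minimal sketch of the idea (all sizes are made up for illustration): the encoder RNN reads the embedded inputs $x_1, \ldots, x_T$ and we simply take the final hidden state $h_T$ as the fixed-length context vector $c$.

import torch
import torch.nn as nn

rnn = nn.GRU(input_size=8, hidden_size=16)  # f: the RNN unit
x = torch.randn(5, 1, 8)                    # [T=5, batch=1, emb_dim=8], embedded inputs x_1..x_T
h_all, h_T = rnn(x)                         # h_all: hidden states of every step, h_T: final hidden state
c = h_T.squeeze(0)                          # simplest choice: c = h_T, shape [1, 16]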

1.1.2 Bidirectional recurrent neural network

The encoder can read its input in the forward direction or in the backward direction. If the input sequence is $x_1, \ldots, x_T$, then in the forward pass the hidden state is
$$\overrightarrow{h}_t = f(x_t, \overrightarrow{h}_{t-1})$$
while in the backward pass the hidden state is computed as
$$\overleftarrow{h}_t = f(x_t, \overleftarrow{h}_{t+1})$$
When we want the encoder to capture information from both directions, we can use a bidirectional recurrent neural network. For example, given the input sequence $x_1, \ldots, x_T$, the forward pass produces the hidden states $\overrightarrow{h}_1, \ldots, \overrightarrow{h}_T$ and the backward pass produces the hidden states $\overleftarrow{h}_1, \ldots, \overleftarrow{h}_T$. In a bidirectional RNN, the hidden state at time step $t$ is the concatenation of $\overrightarrow{h}_t$ and $\overleftarrow{h}_t$. For example:

import torch
h_forward = torch.Tensor([1, 2])
h_backward = torch.Tensor([3, 4])
h_bi = torch.cat((h_forward, h_backward), dim=0)
# tensor([1., 2., 3., 4.])

1.1.3 Decoder

The encoder finally outputs a context vector $c$ that summarizes the input sequence $x_1, \ldots, x_T$. Suppose the output sequence in the training data is $y_1, \ldots, y_{T'}$. We want the output at each time step $t$ to depend both on the previous outputs and on the context vector. We can therefore maximize the joint probability of the output sequence
$$P(y_1, \ldots, y_{T'}) = \prod_{t=1}^{T'} P(y_t \mid y_1, \ldots, y_{t-1}, c)$$
and obtain the loss function of the output sequence
$$-\log P(y_1, \ldots, y_{T'})$$
To this end, we use another recurrent neural network as the decoder. The decoder uses a function $p$ to express the probability of a single output $y_t$:
$$P(y_t \mid y_1, \ldots, y_{t-1}, c) = p(y_{t-1}, s_t, c)$$
where $s_t$ is the decoder's hidden state at time step $t$:
$$s_t = g(y_{t-1}, c, s_{t-1})$$
where the function $g$ is the RNN unit. Note that the encoder and the decoder usually use multi-layer recurrent neural networks.
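
As a small illustration of one decoder step (the sizes below are made up, not the hyper-parameters used later): the hidden state is updated from the previous output $y_{t-1}$, the context vector $c$ and $s_{t-1}$, and a linear-plus-softmax layer turns $s_t$ into a distribution over the output vocabulary.

import torch
import torch.nn as nn
import torch.nn.functional as F

emb_dim, hid_dim, vocab_size = 8, 16, 30
g = nn.GRU(emb_dim + hid_dim, hid_dim)   # g: the decoder RNN unit
p = nn.Linear(hid_dim, vocab_size)       # p: maps s_t to a distribution over the vocabulary

y_prev = torch.randn(1, 1, emb_dim)      # embedding of y_{t-1}
c = torch.randn(1, 1, hid_dim)           # context vector from the encoder
s_prev = torch.randn(1, 1, hid_dim)      # s_{t-1}

_, s_t = g(torch.cat((y_prev, c), dim=2), s_prev)   # s_t = g(y_{t-1}, c, s_{t-1})
prob_y_t = F.softmax(p(s_t.squeeze(0)), dim=1)      # P(y_t | y_1..y_{t-1}, c)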

1.2 Attention mechanism

1.2.1 The attention mechanism in pictures

The attention mechanism came up earlier when we discussed BERT; here we go through it again with pictures. The figure below shows an Encoder architecture; $s_0$ is numerically equal to $h_m$, it has simply been given a different name.

First we compute a "relevance" score between $s_0$ and every $h_i$; for example, the relevance between $s_0$ and $h_1$ is $\alpha_1 = \text{align}(h_1, s_0)$.
Once the $m$ relevance scores $\alpha_1, \ldots, \alpha_m$ have been computed, they are used to take a weighted average of $h_1, \ldots, h_m$:
$$c_0 = \sum_{i=1}^{m} \alpha_i h_i$$
We can get an intuitive feel for what this does: for an $h_i$ with a large weight $\alpha_i$, a large part of $c_0$ will come from that $h_i$. In other words $c_0$ takes the $h_i$ of every time step into account, it just pays more attention to some time steps and less to others; that is the attention mechanism.
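
The weighted average above can be written in a few lines (a sketch; the $h_i$ and the scores below are random placeholders): the alignment scores are normalized with softmax and the context vector $c_0$ is the weighted sum of all encoder hidden states.

import torch
import torch.nn.functional as F

m, hid_dim = 6, 16
h = torch.randn(m, hid_dim)                # h_1 ... h_m
scores = torch.randn(m)                    # align(h_i, s_0) for each i (placeholder values)
alpha = F.softmax(scores, dim=0)           # relevance weights, sum to 1
c0 = (alpha.unsqueeze(1) * h).sum(dim=0)   # c_0 = sum_i alpha_i * h_i, shape [hid_dim]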
Next, $c_0$ is fed into the Decoder as (part of) its input at time step $t_1$ to compute $s_1$, and then the relevance $\alpha_i$ between $s_1$ and every $h_i$ is computed again.
In the same way, the newly computed $\alpha_i$ are used to take a weighted average of the $h_i$, giving a new context vector $c_1$.
These steps are repeated until the Decoder finishes.
At this point the whole of Seq2Seq (with Attention) has essentially been covered, but some details remain: for example, how should the align() function be designed? How is $c_i$ actually used inside the Decoder? These are explained one by one below.

1.2.2 How should the align() function be designed?

There are two common approaches. The first comes from the original paper, i.e. Bahdanau's paper: an additive score, where $s$ and $h_i$ are concatenated, passed through a tanh layer, and projected to a scalar by a learned vector (this is the form implemented in Section 2.2 below).
The approach that is more mainstream now, and the one the Transformer architecture uses, is a dot-product score: $s$ and $h_i$ are first linearly projected and the inner product of the projections is taken as the score.
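
A hedged sketch of the two score functions follows (the weight shapes here are assumptions for illustration, not the exact formulas from the original figures): Bahdanau-style additive attention concatenates $s_0$ and $h_i$, applies tanh and a learned vector $v$, while the dot-product style projects them to a query and a key and takes their inner product.

import torch
import torch.nn as nn

hid_dim = 16
W = nn.Linear(hid_dim * 2, hid_dim, bias=False)    # additive attention weights
v = nn.Linear(hid_dim, 1, bias=False)
W_q = nn.Linear(hid_dim, hid_dim, bias=False)      # query / key projections for dot-product attention
W_k = nn.Linear(hid_dim, hid_dim, bias=False)

def align_additive(h_i, s_0):
    # Bahdanau (additive): v^T tanh(W [h_i ; s_0])
    return v(torch.tanh(W(torch.cat((h_i, s_0), dim=-1))))

def align_dot(h_i, s_0):
    # Transformer-style (dot-product): (W_k h_i)^T (W_q s_0)
    return (W_k(h_i) * W_q(s_0)).sum(dim=-1, keepdim=True)

h_i, s_0 = torch.randn(1, hid_dim), torch.randn(1, hid_dim)
print(align_additive(h_i, s_0).shape, align_dot(h_i, s_0).shape)  # both [1, 1]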

1.2.3 How is $c_i$ used in the Decoder?

(The figure for this step is omitted; Section 2.3 below shows in code exactly how the context vector enters the Decoder: it is concatenated with the embedded previous output and fed into the GRU, and also into the final prediction layer.)

2. Implementing the seq2seq modules

Here we implement four modules: Encoder, Decoder, Attention, and Seq2Seq.

2.1 Encoder

For the Encoder I use a single-layer bidirectional GRU.
The hidden state output of a bidirectional GRU is the concatenation of two vectors, e.g. $h_1 = [\overrightarrow{h}_1; \overleftarrow{h}_1]$, $h_2 = [\overrightarrow{h}_2; \overleftarrow{h}_2]$, and so on. The last-layer hidden states of all time steps make up the GRU's output:
$$\text{output} = \{h_1, h_2, \ldots, h_T\}$$
Assuming this is an $m$-layer GRU, the hidden states of all layers at their final time step make up the GRU's final hidden states:
$$\text{hidden} = \{h^{(1)}, h^{(2)}, \ldots, h^{(m)}\}$$
where each layer contributes a forward and a backward part,
$$h^{(i)} = [\overrightarrow{h}^{(i)}_T; \overleftarrow{h}^{(i)}_1]$$
so
$$\text{hidden} = \{[\overrightarrow{h}^{(1)}_T; \overleftarrow{h}^{(1)}_1], [\overrightarrow{h}^{(2)}_T; \overleftarrow{h}^{(2)}_1], \ldots, [\overrightarrow{h}^{(m)}_T; \overleftarrow{h}^{(m)}_1]\}$$
What we need are the final forward and backward hidden states of the last layer, which we can take out with enc_hidden[-2,:,:] and enc_hidden[-1,:,:]; concatenating them gives a vector we call $s_0$. One last detail: $s_0$ has shape [batch_size, enc_hid_dim*2]. Even without the Attention mechanism, $s_0$ could not be used directly as the Decoder's initial hidden state, because the dimensions do not match: the Decoder's initial hidden state is three-dimensional, while our $s_0$ is two-dimensional, so $s_0$ has to be made three-dimensional and its sizes adjusted. As a first step I pass $s_0$ through a fully connected layer that changes its shape to [batch_size, dec_hid_dim].

That is all there is to the Encoder; the code follows. My code style puts the comment above and the code below.

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super(Encoder,self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src): 
        '''
        src = [src_len, batch_size]
        '''
        src = src.transpose(0, 1) # src = [batch_size, src_len]
        embedded = self.dropout(self.embedding(src)).transpose(0, 1) # embedded = [src_len, batch_size, emb_dim]

        # enc_output = [src_len, batch_size, hid_dim * num_directions]
        # enc_hidden = [n_layers * num_directions, batch_size, hid_dim]
        enc_output, enc_hidden = self.rnn(embedded) # if h_0 is not given, it defaults to a zero tensor

        # enc_hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        # enc_output are always from the last layer

        # enc_hidden [-2, :, : ] is the last of the forwards RNN 
        # enc_hidden [-1, :, : ] is the last of the backwards RNN

        # initial decoder hidden is final hidden state of the forwards and backwards 
        # encoder RNNs fed through a linear layer
        # s = [batch_size, dec_hid_dim]
        s = torch.tanh(self.fc(torch.cat((enc_hidden[-2,:,:], enc_hidden[-1,:,:]), dim = 1)))

        return enc_output, s
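
Before moving on, a quick shape check for the Encoder can help (a sketch that assumes torch and the Encoder class above are already available; the hyper-parameter values are made up):

import torch
enc = Encoder(input_dim=100, emb_dim=32, enc_hid_dim=64, dec_hid_dim=64, dropout=0.5)
src = torch.randint(0, 100, (7, 4))  # [src_len=7, batch_size=4]
enc_output, s = enc(src)
print(enc_output.shape)              # torch.Size([7, 4, 128]) = [src_len, batch_size, enc_hid_dim*2]
print(s.shape)                       # torch.Size([4, 64])     = [batch_size, dec_hid_dim]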

2.2 Attention

Attention boils down to three formulas:
$$E_t = \tanh(\text{attn}(s_{t-1}, H))$$
$$\tilde{a}_t = v E_t$$
$$a_t = \text{softmax}(\tilde{a}_t)$$
Here $s_{t-1}$ is the variable s from the Encoder, $H$ is the Encoder variable enc_output, and attn() is just a simple fully connected network.

We can work backwards from the last formula and derive what shape each variable must have, or what the constraints are. First, $a_t$ must have shape [batch_size, src_len]; there is no question about that. It follows that $\tilde{a}_t$ should also be [batch_size, src_len], or $\tilde{a}_t$ could be three-dimensional with one of its dimensions equal to 1, which squeeze() can turn back into two dimensions. Here let us assume $\tilde{a}_t$ has shape [batch_size, src_len, 1]; I will explain shortly why this assumption is made.

Going one step further up, the variable $v$ must then have shape [?, 1], where ? means I do not yet know what the value should be, and $E_t$ must have shape [batch_size, src_len, ?].

We already know that $H$ has shape [batch_size, src_len, enc_hid_dim*2] and that $s_{t-1}$ currently has shape [batch_size, dec_hid_dim]. These two variables have to be concatenated and fed into the fully connected network, so we first expand $s_{t-1}$ to [batch_size, src_len, dec_hid_dim]; after concatenation the shape becomes [batch_size, src_len, enc_hid_dim*2+dec_hid_dim], which gives us the input and output sizes of the attn() layer:

attn = nn.Linear(enc_hid_dim*2+dec_hid_dim, ?)

At this point every dimension has been derived except the value of ?. Thinking about what ? should be, it turns out there is no real constraint, so we can set it to anything we like (in the code I set ? to dec_hid_dim).
That is all there is to Attention; the code is given below.

class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(Attention,self).__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim, bias=False)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, s, enc_output):

        # s = [batch_size, dec_hid_dim]
        # enc_output = [src_len, batch_size, enc_hid_dim * 2]

        batch_size = enc_output.shape[1]
        src_len = enc_output.shape[0]

        # repeat decoder hidden state src_len times
        # s = [batch_size, src_len, dec_hid_dim]
        # enc_output = [batch_size, src_len, enc_hid_dim * 2]
        s = s.unsqueeze(1).repeat(1, src_len, 1)
        enc_output = enc_output.transpose(0, 1)

        # energy = [batch_size, src_len, dec_hid_dim]
        energy = torch.tanh(self.attn(torch.cat((s, enc_output), dim = 2)))

        # attention = [batch_size, src_len]
        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim=1)
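
A small shape check for the Attention module (a sketch that assumes torch, torch.nn.functional as F and the class above are available; the sizes are made up):

import torch
attn_layer = Attention(enc_hid_dim=64, dec_hid_dim=64)
s = torch.randn(4, 64)               # [batch_size, dec_hid_dim]
enc_output = torch.randn(7, 4, 128)  # [src_len, batch_size, enc_hid_dim*2]
a = attn_layer(s, enc_output)
print(a.shape)                       # torch.Size([4, 7]) = [batch_size, src_len]
print(a.sum(dim=1))                  # every row sums to 1 because of the softmax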

2.3 Decoder

For the Decoder I use a single-layer unidirectional GRU.
The Decoder also comes down to three formulas:
$$c = a_t H$$
$$s_t = \text{GRU}(\text{emb}(y_t), c, s_{t-1})$$
$$\hat{y}_{t+1} = f(\text{emb}(y_t), c, s_t)$$
First, at the start of the Decoder, Attention is called once to obtain the weights $a_t$, whose shape is [batch_size, src_len], while $H$ is enc_output with shape [src_len, batch_size, enc_hid_dim*2]. The two have to be multiplied together while keeping the batch_size dimension, so we first add a dimension to $a_t$, then reorder the dimensions of $H$, and finally multiply them batch-wise (i.e. matrix multiplication within each element of the batch):

a = a.unsqueeze(1) # [batch_size, 1, src_len]
H = H.transpose(0, 1) # [batch_size, src_len, enc_hid_dim*2]
w = torch.bmm(a, H) # [batch_size, 1, enc_hid_dim*2]

Because the GRU only takes two arguments, $y_t$ and w have to be merged into a single input. $y_t$ is in fact the dec_input variable of the Seq2Seq class, with shape [batch_size], so we first add a dimension to $y_t$ and then pass it through the word embedding, which turns it into [batch_size, 1, emb_dim]. Finally we concatenate $\text{emb}(y_t)$ and w:

y = y.unsqueeze(1) # [batch_size, 1]
emb_y = self.emb(y) # [batch_size, 1, emb_dim]
rnn_input = torch.cat((emb_y, w), dim=2) # [batch_size, 1, emb_dim+enc_hid_dim*2]

$s_{t-1}$ has shape [batch_size, dec_hid_dim], so it also needs an extra (leading) dimension before being passed to the GRU:

rnn_input = rnn_input.transpose(0, 1) # [1, batch_size, emb_dim+enc_hid_dim*2]
s = s.unsqueeze(0) # [1, batch_size, dec_hid_dim]
# dec_output = [1, batch_size, dec_hid_dim]
# dec_hidden = [1, batch_size, dec_hid_dim] = the new s (not the previous s)
dec_output, dec_hidden = self.rnn(rnn_input, s)

For the last formula, all three variables are concatenated and passed through a fully connected layer to obtain the final prediction. Checking the shapes first: $\text{emb}(y_t)$ is [batch_size, 1, emb_dim], the context vector w is [batch_size, 1, enc_hid_dim*2], and the new hidden state $s_t$ is [1, batch_size, dec_hid_dim], so we can squeeze and concatenate them as follows:

emb_y = emb_y.squeeze(1) # [batch_size, emb_dim]
c = w.squeeze(1) # [batch_size, enc_hid_dim*2]
s = dec_hidden.squeeze(0) # [batch_size, dec_hid_dim], the new s_t
fc_input = torch.cat((emb_y, c, s), dim=1) # [batch_size, emb_dim+enc_hid_dim*2+dec_hid_dim]

Those are all the details of the Decoder; the code is given below (the snippets above are only illustrative, and their variable names may differ from the code).

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super(Decoder,self).__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_input, s, enc_output):

        # dec_input = [batch_size]
        # s = [batch_size, dec_hid_dim]
        # enc_output = [src_len, batch_size, enc_hid_dim * 2]

        dec_input = dec_input.unsqueeze(1) # dec_input = [batch_size, 1]

        embedded = self.dropout(self.embedding(dec_input)).transpose(0, 1) # embedded = [1, batch_size, emb_dim]

        # a = [batch_size, 1, src_len]  
        a = self.attention(s, enc_output).unsqueeze(1)

        # enc_output = [batch_size, src_len, enc_hid_dim * 2]
        enc_output = enc_output.transpose(0, 1)

        # c = [1, batch_size, enc_hid_dim * 2]
        c = torch.bmm(a, enc_output).transpose(0, 1)

        # rnn_input = [1, batch_size, (enc_hid_dim * 2) + emb_dim]
        rnn_input = torch.cat((embedded, c), dim = 2)

        # dec_output = [src_len(=1), batch_size, dec_hid_dim]
        # dec_hidden = [n_layers * num_directions, batch_size, dec_hid_dim]
        dec_output, dec_hidden = self.rnn(rnn_input, s.unsqueeze(0))

        # embedded = [batch_size, emb_dim]
        # dec_output = [batch_size, dec_hid_dim]
        # c = [batch_size, enc_hid_dim * 2]
        embedded = embedded.squeeze(0)
        dec_output = dec_output.squeeze(0)
        c = c.squeeze(0)

        # pred = [batch_size, output_dim]
        pred = self.fc_out(torch.cat((dec_output, c, embedded), dim = 1))

        return pred, dec_hidden.squeeze(0)
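
A single decoding step can be checked with random tensors (a sketch that assumes torch, torch.nn.functional as F and the classes above are available; all sizes are made up):

import torch
attn_layer = Attention(enc_hid_dim=64, dec_hid_dim=64)
dec = Decoder(output_dim=120, emb_dim=32, enc_hid_dim=64, dec_hid_dim=64, dropout=0.5, attention=attn_layer)
dec_input = torch.randint(0, 120, (4,))  # [batch_size] -- the previous target token
s = torch.randn(4, 64)                   # [batch_size, dec_hid_dim]
enc_output = torch.randn(7, 4, 128)      # [src_len, batch_size, enc_hid_dim*2]
pred, s_new = dec(dec_input, s, enc_output)
print(pred.shape)                        # torch.Size([4, 120]) = [batch_size, output_dim]
print(s_new.shape)                       # torch.Size([4, 64])  = [batch_size, dec_hid_dim]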

2.4 seq2seq(with attention)

A traditional Seq2Seq implementation feeds the words of the target sentence into the Decoder one after another during training; once the Attention mechanism is introduced, I need to control the input word by word explicitly (because feeding each word to the Decoder requires some extra computation first), so in the code you will see a for loop that runs trg_len-1 times (the leading <sos> is fed in manually, hence one iteration fewer).

During training I use a mechanism called Teacher Forcing, which keeps training fast while also adding robustness.
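
The teacher-forcing decision itself is tiny; the sketch below (made-up tensors) shows the coin flip between feeding the ground-truth token and feeding the model's own prediction to the next step.

import random
import torch

teacher_forcing_ratio = 0.5
trg_t = torch.tensor([7, 2, 9])       # ground-truth target tokens at step t, [batch_size=3]
dec_output = torch.randn(3, 20)       # decoder predictions, [batch_size, vocab_size]
top1 = dec_output.argmax(1)           # the model's own prediction
teacher_force = random.random() < teacher_forcing_ratio
next_input = trg_t if teacher_force else top1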

What has to happen inside the for loop? First the variables are passed to the Decoder. Since the Attention computation happens inside the Decoder, the three variables dec_input, s and enc_output are passed in, and the Decoder returns dec_output together with a new s. After that, Teacher Forcing is applied to dec_output according to the chosen probability. The code is below:

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq,self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):

        # src = [src_len, batch_size]
        # trg = [trg_len, batch_size]
        # teacher_forcing_ratio is probability to use teacher forcing

        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        # tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)

        # enc_output is all hidden states of the input sequence, back and forwards
        # s is the final forward and backward hidden states, passed through a linear layer
        enc_output, s = self.encoder(src)

        # first input to the decoder is the <sos> tokens
        dec_input = trg[0,:]

        for t in range(1, trg_len):

            # insert dec_input token embedding, previous hidden state and all encoder hidden states
            # receive output tensor (predictions) and new hidden state
            dec_output, s = self.decoder(dec_input, s, enc_output)

            # place predictions in a tensor holding predictions for each token
            outputs[t] = dec_output

            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio

            # get the highest predicted token from our predictions
            top1 = dec_output.argmax(1) 

            # if teacher forcing, use actual next token as next input
            # if not, use predicted token
            dec_input = trg[t] if teacher_force else top1

        return outputs
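
A quick end-to-end shape check with random data (a sketch; it assumes torch and random are imported, the classes above are defined, and all sizes are made up):

import random
import torch

enc = Encoder(input_dim=100, emb_dim=32, enc_hid_dim=64, dec_hid_dim=64, dropout=0.5)
attn_layer = Attention(enc_hid_dim=64, dec_hid_dim=64)
dec = Decoder(output_dim=120, emb_dim=32, enc_hid_dim=64, dec_hid_dim=64, dropout=0.5, attention=attn_layer)
model = Seq2Seq(enc, dec, device=torch.device('cpu'))

src = torch.randint(0, 100, (7, 4))  # [src_len=7, batch_size=4]
trg = torch.randint(0, 120, (9, 4))  # [trg_len=9, batch_size=4]
outputs = model(src, trg)
print(outputs.shape)                 # torch.Size([9, 4, 120]) = [trg_len, batch_size, output_dim]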

With that, the code for every module of seq2seq (with attention) is complete.

3. Complete seq2seq implementation

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import spacy
import numpy as np

import random
import math
import time

# set random seeds
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# load the German and English spaCy models
! python -m spacy download de
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

"""We create the tokenizers."""

def tokenize_de(text):
    # Tokenizes German text from a string into a list of strings
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    # Tokenizes English text from a string into a list of strings
    return [tok.text for tok in spacy_en.tokenizer(text)]

"""The fields remain the same as before."""

SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

TRG = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

# load the data
train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),fields = (SRC, TRG))

# build the vocabularies
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

# define the training device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# build the data iterators
BATCH_SIZE = 128
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE,
    device = device)

# Encoder model
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super(Encoder,self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src): 
        '''
        src = [src_len, batch_size]
        '''
        src = src.transpose(0, 1) # src = [batch_size, src_len]
        embedded = self.dropout(self.embedding(src)).transpose(0, 1) # embedded = [src_len, batch_size, emb_dim]

        # enc_output = [src_len, batch_size, hid_dim * num_directions]
        # enc_hidden = [n_layers * num_directions, batch_size, hid_dim]
        enc_output, enc_hidden = self.rnn(embedded) # if h_0 is not given, it defaults to a zero tensor

        # enc_hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        # enc_output are always from the last layer

        # enc_hidden [-2, :, : ] is the last of the forwards RNN 
        # enc_hidden [-1, :, : ] is the last of the backwards RNN

        # initial decoder hidden is final hidden state of the forwards and backwards 
        # encoder RNNs fed through a linear layer
        # s = [batch_size, dec_hid_dim]
        s = torch.tanh(self.fc(torch.cat((enc_hidden[-2,:,:], enc_hidden[-1,:,:]), dim = 1)))

        return enc_output, s

# Attention model
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(Attention,self).__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim, bias=False)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, s, enc_output):

        # s = [batch_size, dec_hid_dim]
        # enc_output = [src_len, batch_size, enc_hid_dim * 2]

        batch_size = enc_output.shape[1]
        src_len = enc_output.shape[0]

        # repeat decoder hidden state src_len times
        # s = [batch_size, src_len, dec_hid_dim]
        # enc_output = [batch_size, src_len, enc_hid_dim * 2]
        s = s.unsqueeze(1).repeat(1, src_len, 1)
        enc_output = enc_output.transpose(0, 1)

        # energy = [batch_size, src_len, dec_hid_dim]
        energy = torch.tanh(self.attn(torch.cat((s, enc_output), dim = 2)))

        # attention = [batch_size, src_len]
        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim=1)

# Decoder model
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super(Decoder,self).__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_input, s, enc_output):

        # dec_input = [batch_size]
        # s = [batch_size, dec_hid_dim]
        # enc_output = [src_len, batch_size, enc_hid_dim * 2]

        dec_input = dec_input.unsqueeze(1) # dec_input = [batch_size, 1]

        embedded = self.dropout(self.embedding(dec_input)).transpose(0, 1) # embedded = [1, batch_size, emb_dim]

        # a = [batch_size, 1, src_len]  
        a = self.attention(s, enc_output).unsqueeze(1)

        # enc_output = [batch_size, src_len, enc_hid_dim * 2]
        enc_output = enc_output.transpose(0, 1)

        # c = [1, batch_size, enc_hid_dim * 2]
        c = torch.bmm(a, enc_output).transpose(0, 1)

        # rnn_input = [1, batch_size, (enc_hid_dim * 2) + emb_dim]
        rnn_input = torch.cat((embedded, c), dim = 2)

        # dec_output = [src_len(=1), batch_size, dec_hid_dim]
        # dec_hidden = [n_layers * num_directions, batch_size, dec_hid_dim]
        dec_output, dec_hidden = self.rnn(rnn_input, s.unsqueeze(0))

        # embedded = [batch_size, emb_dim]
        # dec_output = [batch_size, dec_hid_dim]
        # c = [batch_size, enc_hid_dim * 2]
        embedded = embedded.squeeze(0)
        dec_output = dec_output.squeeze(0)
        c = c.squeeze(0)

        # pred = [batch_size, output_dim]
        pred = self.fc_out(torch.cat((dec_output, c, embedded), dim = 1))

        return pred, dec_hidden.squeeze(0)

# Seq2Seq model
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq,self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):

        # src = [src_len, batch_size]
        # trg = [trg_len, batch_size]
        # teacher_forcing_ratio is probability to use teacher forcing

        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        # tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)

        # enc_output is all hidden states of the input sequence, back and forwards
        # s is the final forward and backward hidden states, passed through a linear layer
        enc_output, s = self.encoder(src)

        # first input to the decoder is the <sos> tokens
        dec_input = trg[0,:]

        for t in range(1, trg_len):

            # insert dec_input token embedding, previous hidden state and all encoder hidden states
            # receive output tensor (predictions) and new hidden state
            dec_output, s = self.decoder(dec_input, s, enc_output)

            # place predictions in a tensor holding predictions for each token
            outputs[t] = dec_output

            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio

            # get the highest predicted token from our predictions
            top1 = dec_output.argmax(1) 

            # if teacher forcing, use actual next token as next input
            # if not, use predicted token
            dec_input = trg[t] if teacher_force else top1

        return outputs

# build the seq2seq model
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# training function
def train(model, iterator, optimizer, criterion):
    model.train()    
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg # trg = [trg_len, batch_size]

        # pred = [trg_len, batch_size, pred_dim]
        pred = model(src, trg)

        pred_dim = pred.shape[-1]

        # trg = [(trg len - 1) * batch size]
        # pred = [(trg len - 1) * batch size, pred_dim]
        trg = trg[1:].view(-1)
        pred = pred[1:].view(-1, pred_dim)

        loss = criterion(pred, trg)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

# evaluation function
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg # trg = [trg_len, batch_size]

            # output = [trg_len, batch_size, output_dim]
            output = model(src, trg, 0) # turn off teacher forcing

            output_dim = output.shape[-1]

            # trg = [(trg_len - 1) * batch_size]
            # output = [(trg_len - 1) * batch_size, output_dim]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

# compute the elapsed training time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

"""Then, we train our model, saving the parameters that give us the best validation loss."""

best_valid_loss = float('inf')

# training loop
for epoch in range(10):
    start_time = time.time()

    train_loss = train(model, train_iterator, optimizer, criterion)
    valid_loss = evaluate(model, valid_iterator, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut3-model.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

# load the trained parameters and evaluate on the test set
model.load_state_dict(torch.load('tut3-model.pt'))
test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
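
Finally, to actually translate a sentence with the trained model, a greedy-decoding helper like the one below can be used. This is a sketch I am adding, not part of the original training script; translate_sentence and the sample sentence are hypothetical, and it assumes the variables defined above (spacy_de, SRC, TRG, model, device) are in scope.

def translate_sentence(sentence, src_field, trg_field, model, device, max_len=50):
    model.eval()
    # tokenize the German sentence and wrap it with <sos>/<eos>
    tokens = [tok.text.lower() for tok in spacy_de.tokenizer(sentence)]
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)  # [src_len, 1]
    with torch.no_grad():
        enc_output, s = model.encoder(src_tensor)
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    for _ in range(max_len):
        dec_input = torch.LongTensor([trg_indexes[-1]]).to(device)      # [1]
        with torch.no_grad():
            pred, s = model.decoder(dec_input, s, enc_output)
        pred_token = pred.argmax(1).item()                               # greedy choice
        trg_indexes.append(pred_token)
        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break
    return [trg_field.vocab.itos[i] for i in trg_indexes]

# example (hypothetical input sentence):
# print(translate_sentence('ein mann geht die straße entlang .', SRC, TRG, model, device))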