The Illustrated Transformer


Preface

  • Google proposed the Transformer model, which replaces the RNN architectures previously used in NLP tasks with a Self-Attention structure.
  • Advantage
    • Model training can be parallelized across the positions of a sequence.
  • The Transformer uses the structure common to Seq2Seq tasks: it consists of two parts, an Encoder and a Decoder.


1. Understanding the Transformer at a High Level

  • The Transformer can be split into 2 parts (a minimal sketch follows this list):
    • The encoding component
      • A stack of encoders (the Transformer paper uses 6 encoder layers; the number 6 is not fixed)
        • Each encoder consists of 2 sub-layers:
          • a Self-Attention layer
          • a Feed Forward Neural Network (FFNN)
    • The decoding component
      • A stack of decoders (the Transformer paper uses 6 decoder layers; again, the number 6 is not fixed)
        • Each decoder consists of 3 sub-layers:
          • a Self-Attention layer
          • an Encoder-Decoder Attention layer, which helps the decoder focus on the relevant parts of the input sentence (similar to the Attention in seq2seq models)
          • a Feed Forward Neural Network (FFNN)
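
A minimal sketch of this 6-encoder / 6-decoder stack, using PyTorch's built-in nn.Transformer module instead of the from-scratch code in Section 3 (the hyperparameters below mirror the paper's base configuration; the module choice is only an illustration, not how the original code builds the model):

import torch
import torch.nn as nn

# 6 encoder layers and 6 decoder layers, each built around attention
# plus a position-wise feed-forward network (d_model=512, 8 heads).
model = nn.Transformer(d_model=512, nhead=8,
                       num_encoder_layers=6, num_decoder_layers=6,
                       dim_feedforward=2048, dropout=0.1)

src = torch.rand(10, 32, 512)   # (source length, batch, d_model)
tgt = torch.rand(20, 32, 512)   # (target length, batch, d_model)
print(model(src, tgt).shape)    # torch.Size([20, 32, 512])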

2. Understanding the Transformer in Detail

  • An important property of the Transformer:
    • The word vector at each position flows through the encoder along its own path. These paths interact in the Self-Attention layer but not in the Feed Forward layer, so the Feed Forward computation for all positions can run in parallel (see the sketch below).
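
A small illustration of this point (the layer and tensor shapes below are arbitrary, chosen only for the demo): a position-wise feed-forward layer is a linear map over the last dimension, so the same weights transform every position independently, and one batched call covers the whole sequence.

import torch
import torch.nn as nn

d_model, seq_len, batch = 512, 7, 2
ffn = nn.Linear(d_model, d_model)        # stands in for the FFN sub-layer
x = torch.rand(batch, seq_len, d_model)  # one vector per position

y_parallel = ffn(x)                      # all positions in one call
y_loop = torch.stack([ffn(x[:, i]) for i in range(seq_len)], dim=1)
print(torch.allclose(y_parallel, y_loop, atol=1e-6))   # True: positions are independent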

3. Understanding Self-Attention as a Whole

The reference implementation below builds the full model in PyTorch, from scaled dot-product attention and multi-head attention up to the complete encoder-decoder.


# Standard PyTorch imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable

# For plots
%matplotlib inline
import matplotlib.pyplot as plt
############################ Model Architecture
class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                           tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


class Generator(nn.Module):
    "Define standard linear + softmax generation step."
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)
############################## Encoder and Decoder Stacks
############################## Encoder
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Encoder(nn.Module):
    "Core encoder is a stack of N layers"
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)
############################## Decoder
class Decoder(nn.Module):
    "Generic N layer decoder with masking."
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)


class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    # Upper-triangular part above the main diagonal marks the forbidden (future) positions.
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0
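
# Quick illustrative check (added, not part of the original listing): for a
# target length of 5, position i may only attend to positions <= i, so the
# allowed (True) entries form a lower-triangular pattern.
print(subsequent_mask(5))
# tensor([[[ True, False, False, False, False],
#          [ True,  True, False, False, False],
#          [ True,  True,  True, False, False],
#          [ True,  True,  True,  True, False],
#          [ True,  True,  True,  True,  True]]])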
################################# Attention
##################### Scaled Dot-Product Attention
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
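
# Quick shape check (added for illustration; the dimensions are made up):
# with query/key/value of shape [batch, length, d_k], the attention weights
# are [batch, length, length] and the output keeps the input shape.
q = k = v = torch.rand(2, 5, 64)
out, p_attn = attention(q, k, v)
print(out.shape, p_attn.shape)   # torch.Size([2, 5, 64]) torch.Size([2, 5, 5])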
################### Multi-head attention
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # 1) Project q, k, v and split into heads:
        #    [batch, L, d_model] -> [batch, h, L, d_model/h]
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]
        # 2) Apply attention on all the projected vectors, getting
        #    attn*v and the attention weights attn.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        # 3) Concatenate the heads back into the original sequence shape.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        # Final linear projection.
        return self.linears[-1](x)
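
# Illustrative usage (hyperparameters are arbitrary): 8 heads over d_model=512
# gives d_k=64 per head, and the output keeps the [batch, length, d_model] shape.
mha = MultiHeadedAttention(h=8, d_model=512)
x = torch.rand(2, 10, 512)
print(mha(x, x, x).shape)   # torch.Size([2, 10, 512])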
####################### Position-wise Feed-Forward Networks
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
###################### Embeddings and Softmax
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
######################## Positional Encoding
class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)],
                         requires_grad=False)
        return self.dropout(x)
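
# Illustrative check (dimensions chosen arbitrarily): on a zero input the layer
# returns just the positional encodings; even channels are sine waves and odd
# channels are cosine waves of decreasing frequency.
pe = PositionalEncoding(d_model=20, dropout=0.0)
print(pe(torch.zeros(1, 100, 20)).shape)   # torch.Size([1, 100, 20])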
################################ Full Model
def make_model(src_vocab, tgt_vocab, N=6,
               d_model=512, d_ff=2048, h=8, dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)   # in-place Xavier/Glorot initialization
    return model
tmp_model = make_model(10, 10, 2)
tmp_model
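
As a quick sanity check (added as an illustration; the dummy token ids, lengths and masks below are made up), the small model above can run one forward pass end to end:

src = torch.randint(1, 10, (2, 7))        # [batch, src_len] token ids
tgt = torch.randint(1, 10, (2, 6))        # [batch, tgt_len] token ids
src_mask = torch.ones(2, 1, 7)            # attend to every source position
tgt_mask = subsequent_mask(6)             # hide future target positions

out = tmp_model(src, tgt, src_mask, tgt_mask)   # [2, 6, 512] decoder states
log_probs = tmp_model.generator(out)            # [2, 6, 10] log-probabilities over the target vocab
print(out.shape, log_probs.shape)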