ESIM

参考链接
ESIM分为三个部分:input encoding,local inference modeling 和 inference composition。
第一部分就是用Embedding+LSTM对句子进行编码,提取上下文和句子的关系
第二部分是对两句话进行 alignment,这里是使用 soft_align_attention。

  1. def soft_attention_align(self, x1, x2, mask1, mask2):
  2. '''
  3. x1: batch_size * seq_len * dim
  4. x2: batch_size * seq_len * dim
  5. '''
  6. # attention: batch_size * seq_len * seq_len
  7. #计算句子的相似度
  8. attention = torch.matmul(x1, x2.transpose(1, 2))
  9. mask1 = mask1.float().masked_fill_(mask1, float('-inf'))
  10. mask2 = mask2.float().masked_fill_(mask2, float('-inf'))
  11. # weight: batch_size * seq_len * seq_len
  12. # 得到句子的相似度加权的权重 :local inference
  13. weight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)
  14. x1_align = torch.matmul(weight1, x2)
  15. weight2 = F.softmax(attention.transpose(1, 2) + mask1.unsqueeze(1), dim=-1)
  16. x2_align = torch.matmul(weight2, x1)
  17. # x_align: batch_size * seq_len * hidden_size
  18. return x1_align, x2_align

接下来是Enhancement of local inference information
计算差和点积,和原来的数据拼接之后作为下一级的输入

  1. def submul(self, x1, x2):
  2. mul = x1 * x2
  3. sub = x1 - x2
  4. return torch.cat([sub, mul], -1)
  1. def forward(self, input1,input2):
  2. # batch_size * seq_len
  3. sent1, sent2 = input1, input2
  4. mask1, mask2 = sent1.eq(0), sent2.eq(0)
  5. #print(sent1)(data)
  6. #print(mask1)(True or false by eq 0)
  7. # embeds: batch_size * seq_len => batch_size * seq_len * dim
  8. x1 = self.bn_embeds(self.embeds(sent1).transpose(1, 2).contiguous()).transpose(1, 2)
  9. x2 = self.bn_embeds(self.embeds(sent2).transpose(1, 2).contiguous()).transpose(1, 2)
  10. # batch_size * seq_len * dim => batch_size * seq_len * hidden_size
  11. o1, _ = self.lstm1(x1)
  12. o2, _ = self.lstm1(x2)
  13. # Attention
  14. # batch_size * seq_len * hidden_size
  15. q1_align, q2_align = self.soft_attention_align(o1, o2, mask1, mask2)
  16. # Compose
  17. # batch_size * seq_len * (8 * hidden_size)
  18. q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)
  19. q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)
  20. # batch_size * seq_len * (2 * hidden_size)
  21. q1_compose, _ = self.lstm2(q1_combined)
  22. q2_compose, _ = self.lstm2(q2_combined)
  23. # Aggregate
  24. # input: batch_size * seq_len * (2 * hidden_size)
  25. # output: batch_size * (4 * hidden_size)
  26. q1_rep = self.apply_multiple(q1_compose)
  27. q2_rep = self.apply_multiple(q2_compose)
  28. # Classifier
  29. x = torch.cat([q1_rep, q2_rep], -1)
  30. similarity = self.fc(x)
  31. return similarity

上述的结果输入了第二层的LSTM
最后要做的就是inference composition
其实就是在句子的维度上做了一次池化操作,包括最大池化和平均池化。
池化我之前用的比较少,一开始认为这个池化操作在CNN里比较常见,这里是第一次用在RNN上

  1. ef apply_multiple(self, x):
  2. # input: batch_size * seq_len * (2 * hidden_size)
  3. p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
  4. p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
  5. # output: batch_size * (4 * hidden_size)
  6. return torch.cat([p1, p2], 1)

池化之后拼接输入最后的基本上的是全连接层的一个神经网络块里

  1. self.fc = nn.Sequential(
  2. nn.BatchNorm1d(self.hidden_size * 8),
  3. nn.Linear(self.hidden_size * 8, linear_size),
  4. nn.ELU(inplace=True),
  5. nn.BatchNorm1d(linear_size),
  6. nn.Dropout(self.dropout),
  7. nn.Linear(linear_size, linear_size),
  8. nn.ELU(inplace=True),
  9. nn.BatchNorm1d(linear_size),
  10. nn.Dropout(self.dropout),
  11. nn.Linear(linear_size, 2),
  12. nn.Softmax(dim=-1)
  13. )

也用了很多之前没用过的BatchNormal
最后用一个Sofmax,输出的就是两维的一个,较大的0或者是1就是结果

为啥 ESIM 效果会这么好呢?这里我提几个自己的想法,我觉得 ESIM 牛逼在它的 inter-sentence attention,就是上面代码中的 soft_align_attention,这一步中让要比较的两句话产生了交互。以前我见到的类似 Siamese 网络的结构,往往中间都没有交互,只是在最后一层求个余弦距离或者其他