ESIM
参考链接
ESIM分为三个部分:input encoding,local inference modeling 和 inference composition。
第一部分就是用Embedding+LSTM对句子进行编码,提取上下文和句子的关系
第二部分是对两句话进行 alignment,这里是使用 soft_align_attention。
def soft_attention_align(self, x1, x2, mask1, mask2):'''x1: batch_size * seq_len * dimx2: batch_size * seq_len * dim'''# attention: batch_size * seq_len * seq_len#计算句子的相似度attention = torch.matmul(x1, x2.transpose(1, 2))mask1 = mask1.float().masked_fill_(mask1, float('-inf'))mask2 = mask2.float().masked_fill_(mask2, float('-inf'))# weight: batch_size * seq_len * seq_len# 得到句子的相似度加权的权重 :local inferenceweight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)x1_align = torch.matmul(weight1, x2)weight2 = F.softmax(attention.transpose(1, 2) + mask1.unsqueeze(1), dim=-1)x2_align = torch.matmul(weight2, x1)# x_align: batch_size * seq_len * hidden_sizereturn x1_align, x2_align
接下来是Enhancement of local inference information
计算差和点积,和原来的数据拼接之后作为下一级的输入
def submul(self, x1, x2):mul = x1 * x2sub = x1 - x2return torch.cat([sub, mul], -1)
def forward(self, input1,input2):# batch_size * seq_lensent1, sent2 = input1, input2mask1, mask2 = sent1.eq(0), sent2.eq(0)#print(sent1)(data)#print(mask1)(True or false by eq 0)# embeds: batch_size * seq_len => batch_size * seq_len * dimx1 = self.bn_embeds(self.embeds(sent1).transpose(1, 2).contiguous()).transpose(1, 2)x2 = self.bn_embeds(self.embeds(sent2).transpose(1, 2).contiguous()).transpose(1, 2)# batch_size * seq_len * dim => batch_size * seq_len * hidden_sizeo1, _ = self.lstm1(x1)o2, _ = self.lstm1(x2)# Attention# batch_size * seq_len * hidden_sizeq1_align, q2_align = self.soft_attention_align(o1, o2, mask1, mask2)# Compose# batch_size * seq_len * (8 * hidden_size)q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)# batch_size * seq_len * (2 * hidden_size)q1_compose, _ = self.lstm2(q1_combined)q2_compose, _ = self.lstm2(q2_combined)# Aggregate# input: batch_size * seq_len * (2 * hidden_size)# output: batch_size * (4 * hidden_size)q1_rep = self.apply_multiple(q1_compose)q2_rep = self.apply_multiple(q2_compose)# Classifierx = torch.cat([q1_rep, q2_rep], -1)similarity = self.fc(x)return similarity
上述的结果输入了第二层的LSTM
最后要做的就是inference composition
其实就是在句子的维度上做了一次池化操作,包括最大池化和平均池化。
池化我之前用的比较少,一开始认为这个池化操作在CNN里比较常见,这里是第一次用在RNN上
ef apply_multiple(self, x):# input: batch_size * seq_len * (2 * hidden_size)p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)# output: batch_size * (4 * hidden_size)return torch.cat([p1, p2], 1)
池化之后拼接输入最后的基本上的是全连接层的一个神经网络块里
self.fc = nn.Sequential(nn.BatchNorm1d(self.hidden_size * 8),nn.Linear(self.hidden_size * 8, linear_size),nn.ELU(inplace=True),nn.BatchNorm1d(linear_size),nn.Dropout(self.dropout),nn.Linear(linear_size, linear_size),nn.ELU(inplace=True),nn.BatchNorm1d(linear_size),nn.Dropout(self.dropout),nn.Linear(linear_size, 2),nn.Softmax(dim=-1))
也用了很多之前没用过的BatchNormal
最后用一个Sofmax,输出的就是两维的一个,较大的0或者是1就是结果
为啥 ESIM 效果会这么好呢?这里我提几个自己的想法,我觉得 ESIM 牛逼在它的 inter-sentence attention,就是上面代码中的 soft_align_attention,这一步中让要比较的两句话产生了交互。以前我见到的类似 Siamese 网络的结构,往往中间都没有交互,只是在最后一层求个余弦距离或者其他
