ESIM
参考链接
ESIM分为三个部分:input encoding,local inference modeling 和 inference composition。
第一部分就是用Embedding+LSTM对句子进行编码,提取上下文和句子的关系
第二部分是对两句话进行 alignment,这里是使用 soft_align_attention。
def soft_attention_align(self, x1, x2, mask1, mask2):
'''
x1: batch_size * seq_len * dim
x2: batch_size * seq_len * dim
'''
# attention: batch_size * seq_len * seq_len
#计算句子的相似度
attention = torch.matmul(x1, x2.transpose(1, 2))
mask1 = mask1.float().masked_fill_(mask1, float('-inf'))
mask2 = mask2.float().masked_fill_(mask2, float('-inf'))
# weight: batch_size * seq_len * seq_len
# 得到句子的相似度加权的权重 :local inference
weight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)
x1_align = torch.matmul(weight1, x2)
weight2 = F.softmax(attention.transpose(1, 2) + mask1.unsqueeze(1), dim=-1)
x2_align = torch.matmul(weight2, x1)
# x_align: batch_size * seq_len * hidden_size
return x1_align, x2_align
接下来是Enhancement of local inference information
计算差和点积,和原来的数据拼接之后作为下一级的输入
def submul(self, x1, x2):
mul = x1 * x2
sub = x1 - x2
return torch.cat([sub, mul], -1)
def forward(self, input1,input2):
# batch_size * seq_len
sent1, sent2 = input1, input2
mask1, mask2 = sent1.eq(0), sent2.eq(0)
#print(sent1)(data)
#print(mask1)(True or false by eq 0)
# embeds: batch_size * seq_len => batch_size * seq_len * dim
x1 = self.bn_embeds(self.embeds(sent1).transpose(1, 2).contiguous()).transpose(1, 2)
x2 = self.bn_embeds(self.embeds(sent2).transpose(1, 2).contiguous()).transpose(1, 2)
# batch_size * seq_len * dim => batch_size * seq_len * hidden_size
o1, _ = self.lstm1(x1)
o2, _ = self.lstm1(x2)
# Attention
# batch_size * seq_len * hidden_size
q1_align, q2_align = self.soft_attention_align(o1, o2, mask1, mask2)
# Compose
# batch_size * seq_len * (8 * hidden_size)
q1_combined = torch.cat([o1, q1_align, self.submul(o1, q1_align)], -1)
q2_combined = torch.cat([o2, q2_align, self.submul(o2, q2_align)], -1)
# batch_size * seq_len * (2 * hidden_size)
q1_compose, _ = self.lstm2(q1_combined)
q2_compose, _ = self.lstm2(q2_combined)
# Aggregate
# input: batch_size * seq_len * (2 * hidden_size)
# output: batch_size * (4 * hidden_size)
q1_rep = self.apply_multiple(q1_compose)
q2_rep = self.apply_multiple(q2_compose)
# Classifier
x = torch.cat([q1_rep, q2_rep], -1)
similarity = self.fc(x)
return similarity
上述的结果输入了第二层的LSTM
最后要做的就是inference composition
其实就是在句子的维度上做了一次池化操作,包括最大池化和平均池化。
池化我之前用的比较少,一开始认为这个池化操作在CNN里比较常见,这里是第一次用在RNN上
ef apply_multiple(self, x):
# input: batch_size * seq_len * (2 * hidden_size)
p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
# output: batch_size * (4 * hidden_size)
return torch.cat([p1, p2], 1)
池化之后拼接输入最后的基本上的是全连接层的一个神经网络块里
self.fc = nn.Sequential(
nn.BatchNorm1d(self.hidden_size * 8),
nn.Linear(self.hidden_size * 8, linear_size),
nn.ELU(inplace=True),
nn.BatchNorm1d(linear_size),
nn.Dropout(self.dropout),
nn.Linear(linear_size, linear_size),
nn.ELU(inplace=True),
nn.BatchNorm1d(linear_size),
nn.Dropout(self.dropout),
nn.Linear(linear_size, 2),
nn.Softmax(dim=-1)
)
也用了很多之前没用过的BatchNormal
最后用一个Sofmax,输出的就是两维的一个,较大的0或者是1就是结果
为啥 ESIM 效果会这么好呢?这里我提几个自己的想法,我觉得 ESIM 牛逼在它的 inter-sentence attention,就是上面代码中的 soft_align_attention,这一步中让要比较的两句话产生了交互。以前我见到的类似 Siamese 网络的结构,往往中间都没有交互,只是在最后一层求个余弦距离或者其他