def forward(self, s1, l1, s2, l2):
    """Encode a pair of inputs with one shared encoder and classify the pair.

    Both inputs go through the same embedding layer (``self.Embd``) and the
    same BiLSTM (``self.lstm``, driven by ``torch_util.auto_rnn_bilstm``).
    Each encoding is reduced with ``torch_util.max_along_time``. The two
    pooled vectors ``u`` and ``v`` are combined as ``[u, v, |u - v|, u * v]``
    and passed to ``self.classifier``.

    Args:
        s1: First input batch, fed to ``self.Embd`` (presumably token ids;
            TODO confirm).
        l1: Lengths for ``s1``, forwarded to both torch_util helpers
            (assumed per-example sequence lengths; verify against caller).
        s2: Second input batch, processed exactly like ``s1``.
        l2: Lengths for ``s2``.

    Returns:
        The output of ``self.classifier`` on the concatenated features
        (presumably unnormalised class scores; confirm).
    """
    # Embed both sides first, then encode, then pool. This keeps the same
    # call order across the shared modules.
    emb_a, emb_b = self.Embd(s1), self.Embd(s2)

    enc_a = torch_util.auto_rnn_bilstm(self.lstm, emb_a, l1)
    enc_b = torch_util.auto_rnn_bilstm(self.lstm, emb_b, l2)

    u = torch_util.max_along_time(enc_a, l1)
    v = torch_util.max_along_time(enc_b, l2)

    # Matching features: both vectors, their absolute difference and their
    # element-wise product.
    abs_diff = torch.abs(u - v)
    elem_prod = u * v
    # dim=1 assumes the pooled vectors are laid out as (batch, features);
    # TODO confirm.
    pair_repr = torch.cat([u, v, abs_diff, elem_prod], dim=1)

    return self.classifier(pair_repr)
评论列表
文章目录