def forward(self, input_1, input_2):
    """Merge two vectors by concatenating them with their interaction features.

    The result is ``[input_1, input_2, input_1 * input_2, |input_1 - input_2|]``,
    concatenated along the last dimension.

    :param input_1: Tensor of size (*, hidden_size).
    :param input_2: Tensor of size (*, hidden_size).
    :return: Merged tensor of size (*, 4 * hidden_size).
    :raises ValueError: if the last dimensions of the two inputs differ.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``. Without the check, a trailing size of 1 on one input
    # could broadcast silently and yield an output of the wrong width.
    if input_1.size(-1) != input_2.size(-1):
        raise ValueError(
            "last dimensions must match, got "
            f"{input_1.size(-1)} and {input_2.size(-1)}"
        )
    product = torch.mul(input_1, input_2)
    abs_diff = torch.abs(input_1 - input_2)
    # dim=-1 is the last axis for any rank (equivalent to input_1.dim() - 1).
    return torch.cat((input_1, input_2, product, abs_diff), dim=-1)
评论列表
文章目录