def pack_to_matching_matrix(s1, s2, cat_only=(False, False)):
    """Build a pairwise feature matrix between every position of ``s1`` and ``s2``.

    Args:
        s1: Tensor indexed as ``(t1, batch, d)``; ``t1``, ``batch`` and ``d``
            are read from its three sizes.
        s2: Tensor whose first size is ``t2``; it is viewed as
            ``(t2, 1, batch, d)``, so presumably shaped ``(t2, batch, d)``
            with the same ``batch`` and ``d`` as ``s1``.
        cat_only: Two flags selecting which features are concatenated on
            the last dimension:

            * ``cat_only[0]`` truthy -> ``[s1, s2, |s1 - s2|, s1 * s2]``
              (``cat_only[1]`` is ignored in this case);
            * ``cat_only[0]`` falsy, ``cat_only[1]`` truthy ->
              ``[s1, s2, s1 * s2]``;
            * both falsy (default) -> ``[s1, s2]``.

            A tuple default is used so the default is never a shared
            mutable object; any indexable pair (list or tuple) is accepted.

    Returns:
        Tensor of shape ``(t2, t1, batch, k * d)`` where ``k`` is 2, 3 or 4
        according to ``cat_only``. Entry ``[i, j]`` pairs ``s2[i]`` with
        ``s1[j]``.
    """
    t1, batch_size, d = s1.size(0), s1.size(1), s1.size(2)
    t2 = s2.size(0)

    # Broadcast both sequences onto a common (t2, t1, batch, d) grid:
    # s1 is repeated along the new leading axis, s2 along axis 1.
    expanded_s1 = s1.expand(t2, t1, batch_size, d)
    expanded_s2 = s2.view(t2, 1, batch_size, d).expand(t2, t1, batch_size, d)

    features = [expanded_s1, expanded_s2]
    if cat_only[0]:
        # Full interaction set: absolute difference and element-wise product.
        features.append(torch.abs(expanded_s1 - expanded_s2))
        features.append(expanded_s1 * expanded_s2)
    elif cat_only[1]:
        # Element-wise product only.
        features.append(expanded_s1 * expanded_s2)

    return torch.cat(features, dim=3)
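# Comment list  (page-navigation text left over from the web page this code was copied from)
# Table of contents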