def get_self_distances(self, A, AB, A_to_AB=True):
    """Split ``A`` and ``AB`` in two along dim 2 and delegate the comparison.

    Each tensor is cut into two chunks along dimension 2 with
    ``torch.chunk``. The four resulting pieces are forwarded, together with
    ``A_to_AB``, to ``self.get_individual_distance_loss``, and that call's
    result is returned unchanged.

    Args:
        A: Tensor with at least three dimensions (it is chunked on dim 2).
        AB: Tensor with at least three dimensions (it is chunked on dim 2).
            NOTE(review): presumably the translated counterpart of ``A``;
            confirm against callers.
        A_to_AB: Flag passed through untouched to
            ``self.get_individual_distance_loss``.

    Returns:
        Whatever ``self.get_individual_distance_loss`` returns.

    Raises:
        ValueError: If ``torch.chunk`` yields fewer than two pieces (for
            example when dimension 2 has size 1), the two-way unpacking
            below fails.
    """
    # torch.chunk returns views; when dim 2 has odd length the first chunk
    # is one element longer than the second.
    a_first, a_second = torch.chunk(A, 2, dim=2)
    ab_first, ab_second = torch.chunk(AB, 2, dim=2)
    return self.get_individual_distance_loss(
        a_first, a_second, ab_first, ab_second, A_to_AB
    )
# Web-page residue, not part of the code: "评论列表" (comment list), "文章目录" (article table of contents).