def get_distance_losses(self, A, AB, A_to_AB=True):
    """Average the pairwise distance loss over all sample pairs in a batch.

    ``A`` and ``AB`` are each split along dimension 0 into size-1 chunks.
    For every unordered pair of chunk indices ``(i, j)`` with ``i < j``,
    ``self.get_individual_distance_loss`` is called with the two chunks of
    ``A`` and the two corresponding chunks of ``AB``; the results are summed
    and divided by the number of pairs.

    Args:
        A: Batched tensor; dimension 0 is treated as the sample index.
        AB: Batched tensor indexed in parallel with ``A`` (presumably the
            translation of ``A`` into the other domain -- confirm against
            the caller). Must have at least as many entries along
            dimension 0 as ``A``.
        A_to_AB: Flag forwarded unchanged to
            ``get_individual_distance_loss``.

    Returns:
        The mean of the per-pair losses, or ``0.0`` when ``A`` holds fewer
        than two samples (no pairs exist, so there is nothing to average).
    """
    samples_a = torch.split(A, 1)
    samples_ab = torch.split(AB, 1)

    count = len(samples_a)
    # Number of unordered pairs (i < j) among `count` samples.
    num_pairs = count * (count - 1) // 2
    if num_pairs == 0:
        # A batch of size 0 or 1 has no pairs; avoid dividing by zero.
        return 0.0

    loss_distance_A = 0.0
    for i in range(count - 1):
        for j in range(i + 1, count):
            loss_distance_A += self.get_individual_distance_loss(
                samples_a[i], samples_a[j],
                samples_ab[i], samples_ab[j],
                A_to_AB)

    return loss_distance_A / num_pairs
评论列表
文章目录