def get_distance_losses(self):
    """Average the pairwise distance losses over all sample pairs in the batch.

    Each of ``self.real_A``, ``self.real_B``, ``self.fake_B`` and
    ``self.fake_A`` is split along dim 0 into chunks of size 1. For every
    unordered pair of indices ``(i, j)`` with ``i < j`` among the first
    ``min(len(real_A chunks), len(real_B chunks))`` samples,
    ``self.get_individual_distance_loss`` is called with the i-th/j-th chunks
    of real_A, fake_B, real_B and fake_A (in that order). Its two returned
    losses are accumulated separately.

    Returns:
        Tuple ``(loss_distance_A, loss_distance_B)``: the per-pair losses
        summed and divided by the number of pairs. When fewer than two
        samples are available there are no pairs, and ``(0.0, 0.0)`` is
        returned instead of dividing by zero.
    """
    # NOTE(review): assumes fake_B has at least as many samples as real_A and
    # fake_A at least as many as real_B (presumably generator outputs aligned
    # with their inputs) -- confirm against the caller.
    As = torch.split(self.real_A, 1)
    Bs = torch.split(self.real_B, 1)
    ABs = torch.split(self.fake_B, 1)
    BAs = torch.split(self.fake_A, 1)

    loss_distance_A = 0.0
    loss_distance_B = 0.0
    num_pairs = 0
    min_length = min(len(As), len(Bs))
    # `range` rather than the Python-2-only `xrange`, so this also runs on
    # Python 3. The visiting order of the (i, j) pairs is unchanged.
    for i in range(min_length - 1):
        for j in range(i + 1, min_length):
            num_pairs += 1
            loss_distance_A_ij, loss_distance_B_ij = \
                self.get_individual_distance_loss(As[i], As[j],
                                                  ABs[i], ABs[j],
                                                  Bs[i], Bs[j],
                                                  BAs[i], BAs[j])
            loss_distance_A += loss_distance_A_ij
            loss_distance_B += loss_distance_B_ij

    if num_pairs == 0:
        # Zero or one sample yields no pairs; averaging would divide by zero.
        return loss_distance_A, loss_distance_B

    loss_distance_A = loss_distance_A / num_pairs
    loss_distance_B = loss_distance_B / num_pairs
    return loss_distance_A, loss_distance_B
评论列表
文章目录