distance_gan_model.py source code

python

Project: DistanceGAN  Author: sagiebenaim
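import torch


def get_distance_losses(self):
    # Split each batch tensor along dim 0 into a tuple of single-sample tensors.
    As = torch.split(self.real_A, 1)   # real samples from domain A
    Bs = torch.split(self.real_B, 1)   # real samples from domain B
    ABs = torch.split(self.fake_B, 1)  # A samples mapped into domain B
    BAs = torch.split(self.fake_A, 1)  # B samples mapped into domain A

    loss_distance_A = 0.0
    loss_distance_B = 0.0
    num_pairs = 0
    min_length = min(len(As), len(Bs))

    # Visit every unordered pair (i, j) with i < j (Python 3: range, not xrange).
    for i in range(min_length - 1):
        for j in range(i + 1, min_length):
            num_pairs += 1
            loss_distance_A_ij, loss_distance_B_ij = \
                self.get_individual_distance_loss(As[i], As[j],
                                                  ABs[i], ABs[j],
                                                  Bs[i], Bs[j],
                                                  BAs[i], BAs[j])

            loss_distance_A += loss_distance_A_ij
            loss_distance_B += loss_distance_B_ij

    # Average over all pairs. This needs at least two samples per batch:
    # with a batch of one, num_pairs stays 0 and the division raises ZeroDivisionError.
    loss_distance_A = loss_distance_A / num_pairs
    loss_distance_B = loss_distance_B / num_pairs

    return loss_distance_A, loss_distance_B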
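The loop above averages a per-pair loss over every pair i < j in the batch. The sketch below reproduces that averaging in plain Python so it runs without PyTorch. `mean_abs_distance`, `individual_distance_loss` and `pairwise_distance_loss` are hypothetical names: the body of `get_individual_distance_loss` is not shown on this page, so the unnormalized mean-absolute (L1) distance used here is an assumption. The DistanceGAN paper additionally normalizes distances by per-domain mean and standard deviation, which this sketch omits.

```python
from itertools import combinations
from typing import List, Sequence

Vector = Sequence[float]


def mean_abs_distance(u: Vector, v: Vector) -> float:
    """Mean absolute (L1) difference between two flattened samples (assumed metric)."""
    return sum(abs(a - b) for a, b in zip(u, v)) / len(u)


def individual_distance_loss(x_i: Vector, x_j: Vector,
                             gx_i: Vector, gx_j: Vector) -> float:
    """Hypothetical per-pair loss: |d(x_i, x_j) - d(G(x_i), G(x_j))|, unnormalized."""
    return abs(mean_abs_distance(x_i, x_j) - mean_abs_distance(gx_i, gx_j))


def pairwise_distance_loss(xs: List[Vector], gxs: List[Vector]) -> float:
    """Average the per-pair loss over all unordered pairs i < j, like the loop above."""
    n = min(len(xs), len(gxs))
    pairs = list(combinations(range(n), 2))
    if not pairs:
        return 0.0  # fewer than two samples: nothing to compare
    total = sum(individual_distance_loss(xs[i], xs[j], gxs[i], gxs[j])
                for i, j in pairs)
    return total / len(pairs)


if __name__ == "__main__":
    xs = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
    gxs = [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]]  # a "generator" that doubles distances
    print(pairwise_distance_loss(xs, gxs))  # mean of per-pair losses 1, 2, 1
```

Unlike the original method, this sketch returns 0.0 for a batch with fewer than two samples instead of dividing by zero; a mapping that preserves all pairwise distances (for example the identity) gives a loss of 0.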