main.py source file

python

Project: pytorch_divcolor  Author: aditya12agd5
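import torch

# `args` (nmix, hiddensize) and `get_gmm_coeffs` are defined elsewhere in main.py.
def mdn_loss(gmm_params, mu, stddev, batchsize):
  # Split the network output into mixture means and mixture weights.
  gmm_mu, gmm_pi = get_gmm_coeffs(gmm_params)
  # Reparameterization: z = mu + eps * stddev, eps ~ N(0, I), on stddev's device.
  eps = torch.randn_like(stddev)
  z = torch.add(mu, torch.mul(eps, stddev))
  # Tile z once per mixture component: (batchsize*nmix, hiddensize).
  z_flat = z.repeat(1, args.nmix)
  z_flat = z_flat.view(batchsize*args.nmix, args.hiddensize)
  gmm_mu_flat = gmm_mu.view(batchsize*args.nmix, args.hiddensize)
  # Scaled Euclidean distance from z to every component mean.
  dist_all = torch.sqrt(torch.sum(torch.add(z_flat, gmm_mu_flat.mul(-1)).pow(2).mul(50), 1))
  dist_all = dist_all.view(batchsize, args.nmix)
  # Keep the closest component per sample and its mixture weight.
  dist_min, selectids = torch.min(dist_all, 1)
  gmm_pi_min = torch.gather(gmm_pi, 1, selectids.view(-1, 1))
  # Loss: -log(weight of nearest component) + distance; 1e-30 guards log(0).
  gmm_loss = torch.mean(torch.add(-1*torch.log(gmm_pi_min+1e-30), dist_min))
  gmm_loss_l2 = torch.mean(dist_min)
  return gmm_loss, gmm_loss_l2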
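
The function above depends on the module-level `args` and on `get_gmm_coeffs`, neither of which is shown here. As a minimal sketch of the same nearest-component loss that can be run on its own, the version below takes the mixture means and weights directly. The name `mdn_loss_sketch`, the `scale` parameter, and the `(batch, nmix, hidden)` layout of `gmm_mu` are assumptions made for illustration and are not taken from the project. The layout assumption matches the `view` calls above only if the means of each sample are stored component by component.

```python
import torch


def mdn_loss_sketch(gmm_mu, gmm_pi, z, scale=50.0):
    """Nearest-component mixture loss (illustrative, not the project's code).

    gmm_mu: (batch, nmix, hidden) component means (assumed layout)
    gmm_pi: (batch, nmix) mixture weights, each row summing to 1
    z:      (batch, hidden) sampled latent code
    """
    # Scaled Euclidean distance from each z to each component mean: (batch, nmix).
    diff = z.unsqueeze(1) - gmm_mu
    dist_all = torch.sqrt((diff.pow(2) * scale).sum(dim=2))
    # Closest component per sample, and that component's weight.
    dist_min, select_ids = torch.min(dist_all, dim=1)
    pi_min = gmm_pi.gather(1, select_ids.view(-1, 1)).squeeze(1)
    # Negative log weight of the nearest component plus its distance.
    loss = torch.mean(-torch.log(pi_min + 1e-30) + dist_min)
    return loss, torch.mean(dist_min)


# Usage with random stand-in tensors (sizes are arbitrary).
torch.manual_seed(0)
batch, nmix, hidden = 4, 3, 8
gmm_mu = torch.randn(batch, nmix, hidden)
gmm_pi = torch.softmax(torch.randn(batch, nmix), dim=1)
mu, stddev = torch.zeros(batch, hidden), torch.ones(batch, hidden)
z = mu + torch.randn_like(stddev) * stddev  # reparameterized sample
loss, loss_l2 = mdn_loss_sketch(gmm_mu, gmm_pi, z)
```

Broadcasting `z` against a three-dimensional `gmm_mu` replaces the `repeat`/`view` pair in the original, and squeezing `pi_min` keeps both terms of the sum at shape `(batch,)`, so the mean is taken over one value per sample.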