def mdn_loss(gmm_params, mu, stddev, batchsize, dist_scale=50):
    """Hard-assignment mixture loss for a reparameterized sample.

    Draws ``z = mu + eps * stddev`` with ``eps ~ N(0, 1)``, measures a scaled
    Euclidean distance from ``z`` to each of the ``args.nmix`` mixture means,
    and scores only the nearest component.

    Args:
        gmm_params: Raw mixture parameters, decoded by ``get_gmm_coeffs`` into
            ``(gmm_mu, gmm_pi)``. ``gmm_mu`` must be viewable as
            ``(batchsize * args.nmix, args.hiddensize)``. ``gmm_pi`` is indexed
            along dim 1 by component id; it is presumably
            ``(batchsize, args.nmix)`` mixing weights — TODO confirm against
            ``get_gmm_coeffs``.
        mu: Mean of the sampling distribution. Assumed to be
            ``(batchsize, args.hiddensize)`` — TODO confirm with callers.
        stddev: Standard deviation, same shape as ``mu``. The noise ``eps`` is
            drawn with ``stddev``'s tensor type and device.
        batchsize: Number of examples in the batch.
        dist_scale: Factor applied to the squared difference before the square
            root. The default of 50 reproduces the previously hard-coded value.

    Returns:
        ``(gmm_loss, gmm_loss_l2)``, where ``k*`` is the nearest component and
        ``d_k = sqrt(dist_scale * ||z - mu_k||^2)``:

        - ``gmm_loss`` is the batch mean of ``-log(pi_k* + 1e-30) + d_k*``.
        - ``gmm_loss_l2`` is the batch mean of ``d_k*``. Despite its name, this
          is a scaled, non-squared Euclidean distance.
    """
    gmm_mu, gmm_pi = get_gmm_coeffs(gmm_params)
    nmix, hidden = args.nmix, args.hiddensize

    # Sample eps ~ N(0, 1) with stddev's type and on stddev's device. The
    # previous code drew randn() and then resampled with normal_() (redundant).
    # It also forced the result onto the current CUDA device, which fails on
    # CPU or when stddev lives on a different GPU.
    eps = Variable(stddev.data.new(stddev.size()).normal_())
    z = mu + eps * stddev

    # Pair each sample with every component mean. If z is (batchsize, hidden),
    # then row b * nmix + k of z_flat is z[b], which is compared against row
    # b * nmix + k of gmm_mu_flat.
    z_flat = z.repeat(1, nmix).view(batchsize * nmix, hidden)
    gmm_mu_flat = gmm_mu.view(batchsize * nmix, hidden)
    scaled_sq = (z_flat - gmm_mu_flat).pow(2).mul(dist_scale).sum(1)
    dist_all = torch.sqrt(scaled_sq).view(batchsize, nmix)

    # Hard assignment: only the nearest component contributes. The argmin
    # choice itself is not differentiated; gradients flow through the selected
    # distance and the selected mixing weight.
    dist_min, selectids = torch.min(dist_all, 1)
    dist_min = dist_min.view(-1)
    gmm_pi_min = torch.gather(gmm_pi, 1, selectids.view(-1, 1)).view(-1)

    # Keep both terms 1-D so the sum stays per-example. Adding a (B, 1) tensor
    # to a (B,) tensor would broadcast into a (B, B) matrix. The 1e-30 term
    # guards against log(0).
    gmm_loss = torch.mean(-torch.log(gmm_pi_min + 1e-30) + dist_min)
    gmm_loss_l2 = torch.mean(dist_min)
    return gmm_loss, gmm_loss_l2
评论列表
文章目录