draw_model.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:draw_pytorch 作者: chenzhaomin123 项目源码 文件源码
def loss(self,x):
        self.forward(x)
        criterion = nn.BCELoss()
        x_recons = self.sigmoid(self.cs[-1])
        Lx = criterion(x_recons,x) * self.A * self.B
        Lz = 0
        kl_terms = [0] * T
        for t in xrange(self.T):
            mu_2 = self.mus[t] * self.mus[t]
            sigma_2 = self.sigmas[t] * self.sigmas[t]
            logsigma = self.logsigmas[t]
            # Lz += (0.5 * (mu_2 + sigma_2 - 2 * logsigma))    # 11
            kl_terms[t] = 0.5 * torch.sum(mu_2+sigma_2-2 * logsigma,1) - self.T * 0.5
            Lz += kl_terms[t]
        # Lz -= self.T / 2
        Lz = torch.mean(Lz)    ####################################################
        loss = Lz + Lx    # 12
        return loss


    # correct
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号