nn.py source code


Project: pyprob  Author: probprog
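```python
import torch
from pyprob import util  # assumed import path; util supplies erf, epsilon and the 1/sqrt(2), 1/sqrt(2*pi) constants


def loss(self, x, samples):
    # Run the proposal network. proposal_output has 3 * mixture_components columns:
    # the component means, then the standard deviations, then the mixture coefficients.
    _, proposal_output = self.forward(x, samples)
    batch_size = len(samples)
    means = proposal_output[:, 0:self.mixture_components]
    stds = proposal_output[:, self.mixture_components:2 * self.mixture_components]
    coeffs = proposal_output[:, 2 * self.mixture_components:3 * self.mixture_components]
    l = 0
    for b in range(batch_size):
        value = samples[b].value[0]
        # Support of the prior; every mixture component is truncated to [prior_min, prior_max]
        prior_min = samples[b].distribution.prior_min
        prior_max = samples[b].distribution.prior_max
        ll = 0  # mixture likelihood of this sample
        for c in range(self.mixture_components):
            mean = means[b, c]
            std = stds[b, c]
            coeff = coeffs[b, c]
            xi = (value - mean) / std
            # Standard normal CDF at the truncation bounds: Phi(z) = 0.5 * (1 + erf(z / sqrt(2)))
            phi_min = 0.5 * (1 + util.erf(((prior_min - mean) / std) * util.one_over_sqrt_two))
            phi_max = 0.5 * (1 + util.erf(((prior_max - mean) / std) * util.one_over_sqrt_two))
            # Truncated normal density: pdf(xi) / (std * (Phi(max) - Phi(min)))
            ll += coeff * util.one_over_sqrt_two_pi * torch.exp(-0.5 * xi * xi) / (std * (phi_max - phi_min))
        # Accumulate the negative log-likelihood; epsilon guards against log(0)
        l -= torch.log(ll + util.epsilon)
    # The loss is summed (not averaged) over the batch
    return l
```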
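Written out, the loop above computes the negative log-likelihood of a mixture of truncated normals. The notation here is introduced only for illustration: x_b is samples[b].value[0], ℓ_b and u_b are prior_min and prior_max, μ_bc, σ_bc and π_bc are the means, stds and coeffs entries, C is mixture_components, B is the batch size, and ε is util.epsilon.

```latex
\mathcal{L} = -\sum_{b=1}^{B} \log\!\left(
  \sum_{c=1}^{C} \pi_{bc}\,
  \frac{\varphi\!\left(\frac{x_b - \mu_{bc}}{\sigma_{bc}}\right)}
       {\sigma_{bc}\left[\Phi\!\left(\frac{u_b - \mu_{bc}}{\sigma_{bc}}\right)
                       - \Phi\!\left(\frac{\ell_b - \mu_{bc}}{\sigma_{bc}}\right)\right]}
  + \epsilon \right),
\qquad
\varphi(z) = \frac{e^{-z^2/2}}{\sqrt{2\pi}},
\qquad
\Phi(z) = \tfrac{1}{2}\left(1 + \operatorname{erf}\!\left(\tfrac{z}{\sqrt{2}}\right)\right)
```

This is a proper density only if, for each sample, the coeffs are non-negative and sum to one and the stds are positive; the snippet does not show how the network enforces that.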
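The per-sample term can be checked by hand without PyTorch. The following is a minimal pure-Python sketch, assuming plain float inputs; the function names and EPSILON = 1e-8 are hypothetical (the snippet's util.epsilon value is not shown).

```python
import math

ONE_OVER_SQRT_TWO = 1.0 / math.sqrt(2.0)
ONE_OVER_SQRT_TWO_PI = 1.0 / math.sqrt(2.0 * math.pi)
EPSILON = 1e-8  # hypothetical stand-in for util.epsilon


def std_normal_cdf(z):
    """Phi(z) for a standard normal, written via erf as in the snippet."""
    return 0.5 * (1.0 + math.erf(z * ONE_OVER_SQRT_TWO))


def truncated_mixture_nll(value, prior_min, prior_max, means, stds, coeffs):
    """NLL of one value under a mixture of normals, each truncated to [prior_min, prior_max]."""
    likelihood = 0.0
    for mean, std, coeff in zip(means, stds, coeffs):
        xi = (value - mean) / std
        # Probability mass of this component inside the truncation bounds
        mass = std_normal_cdf((prior_max - mean) / std) - std_normal_cdf((prior_min - mean) / std)
        pdf = ONE_OVER_SQRT_TWO_PI * math.exp(-0.5 * xi * xi)
        likelihood += coeff * pdf / (std * mass)
    return -math.log(likelihood + EPSILON)


def batch_nll(batch):
    """Sum of per-sample NLLs, mirroring how the snippet accumulates l over the batch."""
    return sum(truncated_mixture_nll(*item) for item in batch)
```

With a single component of coefficient 1 and very wide bounds, the truncation mass is essentially 1 and the result reduces to the ordinary normal NLL, which at the mean is 0.5 * log(2π) ≈ 0.919.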