nn.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:pyprob 作者: probprog 项目源码 文件源码
def loss(self, x, samples):
        _, proposal_output = self.forward(x, samples)
        batch_size = len(samples)
        modes = proposal_output[:, 0]
        certainties = proposal_output[:, 1] + 2
        alphas = modes * (certainties - 2) + 1
        betas = (1 - modes) * (certainties - 2) + 1
        beta_funs = util.beta(alphas, betas)
        l = 0
        for b in range(batch_size):
            value = samples[b].value[0]
            alpha = alphas[b]
            beta = betas[b]
            beta_fun = beta_funs[b]
            l -= (alpha - 1) * np.log(value + util.epsilon) + (beta - 1) * np.log(1 - value + util.epsilon) - torch.log(beta_fun + util.epsilon)
        return l
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号