vfm.py source code

python

Project: vfm    Author: cemoody
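import chainer.functions as F


def term_bias(self, bs, train=True):
    """Compute the overall bias and broadcast it to the batch size.

    Returns one sampled bias per example, shape (bs,), and the KL
    divergence of the bias distribution from the N(0, 1) prior.
    """
    shape = (bs, 1)
    # The bias is drawn from a Gaussian with a learned mean (bias_mu)
    # and a learned log variance (bias_lv), broadcast to one row per example.
    bs_mu = F.broadcast_to(self.bias_mu.b, shape)
    bs_lv = F.broadcast_to(self.bias_lv.b, shape)

    # For validation, add a very negative offset (lv_floor) to the log
    # variance so that we sample from a very narrow distribution about the
    # mean, i.e. we effectively guess only the mean. This has to happen
    # before sampling for it to affect the draw.
    if not train:
        bs_lv = bs_lv + self.lv_floor

    # Reparameterized sample: bias = mu + exp(lv / 2) * eps, eps ~ N(0, 1)
    bias = F.flatten(F.gaussian(bs_mu, bs_lv))

    # Prior on the bias: KL(N(mu_bias, var_bias) || N(0, 1)), computed from
    # the unbroadcast parameters so it is counted once, not once per example.
    kld = F.gaussian_kl_divergence(self.bias_mu.b, self.bias_lv.b)
    return bias, kld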
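To make the arithmetic behind term_bias explicit without Chainer, here is a minimal standard-library sketch under stated assumptions: the bias is a single scalar with mean mu and log variance lv, samples use the reparameterization mu + exp(lv / 2) * eps, and the KL term uses the closed form 0.5 * (mu^2 + exp(lv) - lv - 1), which is the per-element quantity Chainer's F.gaussian_kl_divergence sums. The helper names (sample_bias, gaussian_kl_to_standard_normal, and this standalone term_bias) and the lv_floor value of -20.0 are hypothetical; the source does not show what self.lv_floor is set to. The floor is applied before drawing samples, which is what the narrowing described in the comments requires.

```python
import math
import random


def sample_bias(mu, lv, batch_size, lv_floor=None, rng=random):
    """Draw one bias per example from N(mu, exp(lv)) by reparameterization.

    If lv_floor is given (the validation case), it is added to the log
    variance before sampling, so every draw collapses toward mu.
    """
    if lv_floor is not None:
        lv = lv + lv_floor
    sigma = math.exp(0.5 * lv)  # standard deviation from the log variance
    # bias = mu + sigma * eps with eps ~ N(0, 1), one draw per batch element
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(batch_size)]


def gaussian_kl_to_standard_normal(mu, lv):
    """Closed-form KL(N(mu, exp(lv)) || N(0, 1)) for a scalar Gaussian."""
    return 0.5 * (mu * mu + math.exp(lv) - lv - 1.0)


def term_bias(mu, lv, batch_size, train=True, lv_floor=-20.0, rng=random):
    """Hypothetical standalone analogue of the method above.

    lv_floor=-20.0 is an illustrative assumption, not the project's value.
    """
    bias = sample_bias(mu, lv, batch_size, None if train else lv_floor, rng)
    # The KL term uses the learned parameters, not the floored log variance.
    kld = gaussian_kl_to_standard_normal(mu, lv)
    return bias, kld


if __name__ == "__main__":
    rng = random.Random(0)
    train_bias, kld = term_bias(0.5, 0.0, batch_size=4, train=True, rng=rng)
    val_bias, _ = term_bias(0.5, 0.0, batch_size=4, train=False, rng=rng)
    print(len(train_bias), round(kld, 4))  # 4 0.125
    print(all(abs(b - 0.5) < 1e-3 for b in val_bias))  # True
```

With lv = 0 the training draws have unit variance, while in the validation call the log variance becomes -20, so the standard deviation is exp(-10), roughly 4.5e-5, and the draws are indistinguishable from the mean. Keeping the KL term on the unfloored parameters means the floor changes only how the bias is sampled, not the regularization penalty.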