def term_bias(self, bs, train=True):
    """Sample the overall bias term for a batch and compute its KL penalty.

    The bias is a single learned Gaussian parameterised by ``self.bias_mu.b``
    (mean) and ``self.bias_lv.b`` (log variance). Both are broadcast to shape
    ``(bs, 1)`` so that one independent sample is drawn per example. The
    samples are then flattened.

    Args:
        bs: Batch size, i.e. the number of bias samples to draw.
        train: When False, ``self.lv_floor`` is added to the log variance
            *before* sampling. The intent is to sample from a very narrow
            distribution about the mean, e.g. for a validation set.
            NOTE(review): this assumes ``lv_floor`` is a large negative
            number; confirm where it is set.

    Returns:
        Tuple ``(bias, kld)``:
            bias: the flattened bias samples, one per batch element.
            kld: the KL divergence of N(mu_bias, var_bias) from N(0, 1),
                computed on the un-broadcast parameters. It therefore does
                not depend on ``bs`` or ``train``.
    """
    shape = (bs, 1)
    # Bias is drawn from a Gaussian with the given mean and log variance,
    # broadcast so every example in the batch gets its own draw.
    bs_mu = F.broadcast_to(self.bias_mu.b, shape)
    bs_lv = F.broadcast_to(self.bias_lv.b, shape)
    if not train:
        # Shift the log variance so sampling collapses toward the mean.
        # This must happen before F.gaussian is called; applying it after
        # the draw (as before) left the returned bias unaffected.
        bs_lv = bs_lv + self.lv_floor
    bias = F.flatten(F.gaussian(bs_mu, bs_lv))
    # Prior on the bias: KL(N(mu_bias, var_bias) || N(0, 1)).
    kld = F.gaussian_kl_divergence(self.bias_mu.b, self.bias_lv.b)
    return bias, kld
评论列表
文章目录