def term_feat(self, iloc, jloc, ival, jval, bs, nf, train=True):
    # Broadcast the shared (group) mean and log-variance vectors so they
    # line up with the per-interaction shapes
    shape = (bs, nf * 2, self.n_dim)
    feat_mu_vec = F.broadcast_to(self.feat_mu_vec.b, shape)
    feat_lv_vec = F.broadcast_to(self.feat_lv_vec.b, shape)
    if not train:
        # At evaluation time, shift the log-variance by a fixed floor
        feat_lv_vec += self.lv_floor
    # Construct the interaction mean and variance.
    # iloc is (bs, nf * 2), so feat_delta_mu(iloc) is (bs, nf * 2, n_dim)
    # and the inner product over n_dim is (bs, nf * 2)
    ivec = F.gaussian(feat_mu_vec + self.feat_delta_mu(iloc),
                      feat_lv_vec + self.feat_delta_lv(iloc))
    jvec = F.gaussian(feat_mu_vec + self.feat_delta_mu(jloc),
                      feat_lv_vec + self.feat_delta_lv(jloc))
    # Weight each sampled interaction by the product of its feature
    # values and reduce to one term per example; feat is (bs,)
    feat = dot(F.sum(ivec * jvec, axis=2), ival * jval)
    # Compute the KL divergence for the group mean and variance vectors:
    # KL(N(group mu, group lv) || N(0, hyper_lv))
    # where the hyperprior is hyper_lv ~ Gamma(1, 1)
    kldg = F.sum(kl_div(self.feat_mu_vec.b, self.feat_lv_vec.b,
                        self.hyper_feat_lv_vec.b))
    # Compute the per-feature deviations from the hyperprior:
    # KL(N(delta mu_i, delta lv_i) || N(0, hyper_delta_lv))
    # where hyper_delta_lv ~ Gamma(1, 1)
    kldi = F.sum(kl_div(self.feat_delta_mu.W, self.feat_delta_lv.W,
                        self.hyper_feat_delta_lv.b))
    # Hyperprior penalty for log(var) ~ Gamma(alpha=1, beta=1).
    # The Gamma(1, 1) density is p(x) = exp(-x), so its log-density is
    # -x up to a constant; evaluated at x = log(var) this is -log(var).
    # The loss will push log(var) to be as negative as possible, which
    # in turn makes the variance as small as possible.
    # F.sum just reduces the 1D vector to a scalar
    hyperg = -F.sum(self.hyper_feat_lv_vec.b)
    hyperi = -F.sum(self.hyper_feat_delta_lv.b)
    return feat, kldg, kldi, hyperg, hyperi
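
Neither dot nor kl_div is defined in this snippet. Below is a minimal sketch of both helpers, consistent with the shapes and comments above; these are assumed implementations, not necessarily the author's actual code.

import chainer.functions as F

def dot(a, b):
    # Batched inner product: elementwise multiply, then reduce the last
    # axis, e.g. (bs, nf * 2) . (bs, nf * 2) -> (bs,)
    return F.sum(a * b, axis=-1)

def kl_div(mu, lv, lv2):
    # Elementwise KL(N(mu, exp(lv)) || N(0, exp(lv2))) between diagonal
    # Gaussians parameterized by log-variance:
    #   0.5 * (lv2 - lv + (exp(lv) + mu^2) * exp(-lv2) - 1)
    # Assumes mu, lv, and lv2 broadcast against one another.
    return 0.5 * (lv2 - lv + (F.exp(lv) + mu * mu) * F.exp(-lv2) - 1.0)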
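
For context on the sampling step: chainer.functions.gaussian(mean, ln_var) implements the reparameterization trick, returning mean + exp(ln_var / 2) * eps with eps ~ N(0, 1), so the sample stays differentiable with respect to both arguments. A NumPy sketch of the same operation:

import numpy as np

def gaussian_sample(mu, ln_var, rng=np.random):
    # Reparameterized sample from N(mu, exp(ln_var)): a deterministic
    # function of (mu, ln_var) plus standard normal noise.
    eps = rng.standard_normal(mu.shape).astype(mu.dtype)
    return mu + np.exp(0.5 * ln_var) * eps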