def term_feat(self, iloc, jloc, ival, jval, bs, nf, train=True):
    """Build the pairwise feature-interaction term and its KL penalty.

    A shared (group-level) mean / log-variance bias is broadcast and
    combined with per-index deviations looked up through
    ``feat_delta_mu`` / ``feat_delta_lv``; the sums parameterise the
    Gaussians from which the two interaction vectors are drawn.

    Args:
        iloc: Indices passed to the delta lookups for the first member of
            each pair (original notes say ``(bs, nf)`` -- TODO confirm).
        jloc: Indices for the second member of each pair, used like
            ``iloc``.
        ival: Multiplied element-wise with ``jval`` before being handed
            to ``dot``.
        jval: See ``ival``.
        bs: Leading dimension of the broadcast shape (presumably the
            batch size -- verify against caller).
        nf: The broadcast shape's second dimension is ``nf * 2``.
        train: When false, ``self.lv_floor`` is added to the broadcast
            log-variance before sampling.

    Returns:
        Tuple ``(feat, kld)``: the interaction term (original notes say
        shape ``(bs,)``) and the sum of the two Gaussian KL divergences.
    """
    # Expand the shared bias vectors to the interaction-vector shape.
    target_shape = (bs, nf * 2, self.n_dim)
    group_mu = F.broadcast_to(self.feat_mu_vec.b, target_shape)
    group_lv = F.broadcast_to(self.feat_lv_vec.b, target_shape)
    if not train:
        group_lv += self.lv_floor

    def sample(loc):
        # Group-level parameters plus the per-index deviation, then draw.
        mean = group_mu + self.feat_delta_mu(loc)
        ln_var = group_lv + self.feat_delta_lv(loc)
        return F.gaussian(mean, ln_var)

    # Keep the i-then-j order so random draws are consumed as before.
    ivec = sample(iloc)
    jvec = sample(jloc)

    # Reduce the element-wise product over axis 2, then weight by the
    # feature values through `dot` (helper defined outside this block).
    pair_products = F.sum(ivec * jvec, axis=2)
    feat = dot(pair_products, ival * jval)

    # KL term for the group-level mean / log-variance vectors.
    group_kld = F.gaussian_kl_divergence(self.feat_mu_vec.b,
                                         self.feat_lv_vec.b)
    # KL term for the per-index deviations from the group-level values.
    delta_kld = F.gaussian_kl_divergence(self.feat_delta_mu.W,
                                         self.feat_delta_lv.W)
    return feat, group_kld + delta_kld
# [web-page residue] Comment list
# [web-page residue] Article table of contents