vfm.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:vfm 作者: cemoody 项目源码 文件源码
def term_feat(self, iloc, jloc, ival, jval, bs, nf, train=True):
        """Return the pairwise feature-interaction term and its KL penalty.

        A shared (group) mean vector and lv vector are broadcast to
        ``(bs, nf * 2, self.n_dim)``, offset by per-feature deltas looked up
        with ``iloc`` / ``jloc``, and passed to ``F.gaussian`` to draw one
        vector per side of each feature pair.  The two draws are multiplied
        elementwise, summed over axis 2 and combined with ``ival * jval``
        through ``dot``.

        Args:
            iloc: Indices fed to ``self.feat_delta_mu`` / ``self.feat_delta_lv``
                for the first member of each pair.  The original author notes
                this as shape (bs, nf) -- TODO confirm against the caller.
            jloc: Same as ``iloc`` for the second member of each pair.
            ival: Values for the first member of each pair; multiplied
                elementwise with ``jval``.
            jval: Values for the second member of each pair.
            bs: First dimension of the broadcast shape (presumably the batch
                size -- confirm against the caller).
            nf: The broadcast shape's second dimension is ``nf * 2``.
                NOTE(review): the original comments describe the looked-up
                features as (bs, nf, ndim) while the broadcast target is
                ``nf * 2`` wide -- verify what ``nf`` means at the call site.
            train: When false, ``self.lv_floor`` is added to the broadcast lv
                vector before sampling.

        Returns:
            A ``(feat, kld)`` tuple.  ``feat`` is the result of ``dot`` (noted
            by the original author as shape (bs,)); ``kld`` is the sum of
            ``F.gaussian_kl_divergence`` over the group parameters and over
            the delta embedding weights.
        """
        # Expand the shared group parameters to one vector per feature slot.
        target_shape = (bs, nf * 2, self.n_dim)
        group_mu = F.broadcast_to(self.feat_mu_vec.b, target_shape)
        group_lv = F.broadcast_to(self.feat_lv_vec.b, target_shape)
        if not train:
            # Outside of training the lv vector is shifted by lv_floor.
            group_lv += self.lv_floor

        # Draw the vector for the first member of every pair: group
        # parameters plus the per-feature deviation selected by iloc.
        i_mu = group_mu + self.feat_delta_mu(iloc)
        i_lv = group_lv + self.feat_delta_lv(iloc)
        i_sample = F.gaussian(i_mu, i_lv)

        # Same construction for the second member, selected by jloc.
        j_mu = group_mu + self.feat_delta_mu(jloc)
        j_lv = group_lv + self.feat_delta_lv(jloc)
        j_sample = F.gaussian(j_mu, j_lv)

        # Inner product over the latent axis, then weight each pair by the
        # product of its two feature values and reduce with dot.
        pair_inner = F.sum(i_sample * j_sample, axis=2)
        pair_weight = ival * jval
        feat = dot(pair_inner, pair_weight)

        # KL term for the group-level mean / lv vectors ...
        kld_group = F.gaussian_kl_divergence(self.feat_mu_vec.b,
                                             self.feat_lv_vec.b)
        # ... and for the per-feature deviations from the group values.
        kld_delta = F.gaussian_kl_divergence(self.feat_delta_mu.W,
                                             self.feat_delta_lv.W)
        return feat, kld_group + kld_delta
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号