auto_vfm.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:vfm 作者: cemoody 项目源码 文件源码
def term_feat(self, iloc, jloc, ival, jval, bs, nf, train=True):
        """Sample pairwise interaction features and their regulariser terms.

        One ``F.gaussian`` sample is drawn for the ``iloc`` side and one for
        the ``jloc`` side. Each sample's mean is the broadcast group mean plus
        the looked-up delta mean; its log-variance is the broadcast group
        log-variance plus the looked-up delta log-variance. The two samples
        are multiplied, summed over axis 2, and combined with
        ``ival * jval`` through the ``dot`` helper defined outside this
        function.

        Args:
            iloc, jloc: indices passed to ``self.feat_delta_mu`` and
                ``self.feat_delta_lv``. NOTE(review): the group vectors are
                broadcast to ``(bs, nf * 2, self.n_dim)``, so the lookups
                presumably return that shape as well (an older comment said
                ``(bs, nf)``) -- confirm against the caller.
            ival, jval: multiplied elementwise and handed to ``dot``.
            bs: batch size; first dimension of the broadcast shape.
            nf: second dimension of the broadcast shape is ``nf * 2``.
            train: when False, ``self.lv_floor`` is added to the broadcast
                group log-variance before sampling.

        Returns:
            Tuple ``(feat, kldg, kldi, hyperg, hyperi)``:
            feat   -- result of ``dot`` on the summed sample products
            kldg   -- summed ``kl_div`` of the group mean / log-variance
                      biases against ``self.hyper_feat_lv_vec.b``
            kldi   -- summed ``kl_div`` of the delta embedding weights
                      against ``self.hyper_feat_delta_lv.b``
            hyperg -- negated sum of ``self.hyper_feat_lv_vec.b``
            hyperi -- negated sum of ``self.hyper_feat_delta_lv.b``
        """
        # Broadcast the shared (group) mean and log-variance biases so they
        # line up with every interaction slot in the batch.
        target_shape = (bs, nf * 2, self.n_dim)
        group_mu = F.broadcast_to(self.feat_mu_vec.b, target_shape)
        group_lv = F.broadcast_to(self.feat_lv_vec.b, target_shape)
        if not train:
            # Outside training the log-variance is raised by lv_floor.
            group_lv += self.lv_floor

        def draw(loc):
            # Group parameters plus per-index deltas -> one Gaussian sample.
            mean = group_mu + self.feat_delta_mu(loc)
            ln_var = group_lv + self.feat_delta_lv(loc)
            return F.gaussian(mean, ln_var)

        # F.gaussian consumes random numbers, so the i side is always sampled
        # before the j side.
        left = draw(iloc)
        right = draw(jloc)
        feat = dot(F.sum(left * right, axis=2), ival * jval)

        # KL penalties. kl_div is defined outside this view; presumably it is
        # KL(N(mu, exp(lv)) || N(0, exp(hyper_lv))) -- confirm.
        # Group term: the group bias vectors vs. the group hyper log-variance.
        kl_group = F.sum(kl_div(self.feat_mu_vec.b, self.feat_lv_vec.b,
                                self.hyper_feat_lv_vec.b))
        # Delta term: the delta embedding weight matrices vs. the delta hyper
        # log-variance.
        kl_delta = F.sum(kl_div(self.feat_delta_mu.W, self.feat_delta_lv.W,
                                self.hyper_feat_delta_lv.b))

        # Hyperprior terms: negated sums of the hyper log-variances (F.sum
        # reduces each vector to a scalar).
        # NOTE(review): read as a Gamma(1, 1) log-density evaluated at lv,
        # this value is -lv. If the caller adds it to a loss that is
        # minimised, lv is pushed *up*, not down -- confirm the sign
        # convention at the call site.
        hyper_group = -F.sum(self.hyper_feat_lv_vec.b)
        hyper_delta = -F.sum(self.hyper_feat_delta_lv.b)
        return feat, kl_group, kl_delta, hyper_group, hyper_delta
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号