corex.py source code

python

Project: bio_corex  Author: gregversteeg
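
import numpy as np
import numpy.ma as ma

def calculate_mis(self, p_y_given_x, theta, Xm):
    """Estimate MI (in nats) between each hidden factor and each visible variable.
    p_y_given_x: (n_hidden, n_samples, dim_hidden); Xm: masked array (n_samples, n_visible), mask = missing.
    """
    mis = np.zeros((self.n_hidden, self.n_visible))
    # Use a random subsample of at most self.max_samples rows, drawn without replacement.
    sample = np.random.choice(np.arange(Xm.shape[0]), min(self.max_samples, Xm.shape[0]), replace=False)
    # Count the non-missing entries of each visible variable within the subsample.
    n_observed = np.sum(np.logical_not(ma.getmaskarray(Xm[sample])), axis=0)

    # Split the visible variables into batches sized to fit the self.ram budget (GB).
    n_samples, n_visible = Xm.shape
    memory_size = float(n_samples * n_visible * self.n_hidden * self.dim_hidden * 64) / 1000**3  # GB
    batch_size = np.clip(int(self.ram * n_visible / memory_size), 1, n_visible)
    for i in range(0, n_visible, batch_size):
        log_marg_x = self.calculate_marginals_on_samples(theta[i:i+batch_size, ...], Xm[sample, i:i+batch_size])  # n_hidden, n_samples, n_visible, dim_hidden
        # Sum p(y|x) * log_marg_x over samples and hidden states, then divide by each variable's observed count.
        mis[:, i:i+batch_size] = np.einsum('ijl,ijkl->ik', p_y_given_x[:, sample, :], log_marg_x) / n_observed[i:i+batch_size][np.newaxis, :]
    return mis  # MI in nats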
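The core of `calculate_mis` is the contraction `np.einsum('ijl,ijkl->ik', ...)`. For hidden factor i and visible variable k it computes mis[i, k] = (1/N_k) Σ_j Σ_l p(y_i = l | x^(j)) · log_marg_x[i, j, k, l], where j runs over the subsampled rows, l over the `dim_hidden` states, and N_k is `n_observed[k]`. The sketch below checks that reading against explicit loops. The array sizes, the random `log_marg_x` (a stand-in for the output of `calculate_marginals_on_samples`) and the `n_observed` counts are all hypothetical toy values, not anything produced by bio_corex.

```python
import numpy as np

rng = np.random.default_rng(0)
n_hidden, n_samples, n_visible, dim_hidden = 2, 5, 3, 4

# p(y_i = l | x^(j)): normalize so each (factor, sample) row sums to 1 over hidden states.
p_y_given_x = rng.random((n_hidden, n_samples, dim_hidden))
p_y_given_x /= p_y_given_x.sum(axis=2, keepdims=True)

# Hypothetical stand-in for calculate_marginals_on_samples output,
# shape (n_hidden, n_samples, n_visible, dim_hidden) as noted in calculate_mis.
log_marg_x = rng.normal(size=(n_hidden, n_samples, n_visible, dim_hidden))

# Hypothetical number of non-missing samples per visible variable.
n_observed = np.array([5, 4, 5])

# The contraction used in calculate_mis: sum over samples (j) and hidden states (l).
mis_einsum = np.einsum('ijl,ijkl->ik', p_y_given_x, log_marg_x) / n_observed[np.newaxis, :]

# The same quantity written as explicit loops.
mis_loop = np.zeros((n_hidden, n_visible))
for i in range(n_hidden):
    for k in range(n_visible):
        total = 0.0
        for j in range(n_samples):
            for l in range(dim_hidden):
                total += p_y_given_x[i, j, l] * log_marg_x[i, j, k, l]
        mis_loop[i, k] = total / n_observed[k]

print(mis_einsum.shape, np.allclose(mis_einsum, mis_loop))
```

The `einsum` form avoids materializing the product p_y_given_x × log_marg_x as a separate 4-D array, which matters because `log_marg_x` is already the largest intermediate in the method.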
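The batching in `calculate_mis` estimates the size of the full marginals array in GB and then takes as many visible columns per batch as fit in `self.ram`, clipped to the range [1, n_visible]. The sketch below pulls that heuristic out into two standalone helpers so it can be tried with concrete numbers. The function names `batch_size_for` and `batches` are hypothetical, and `ram_gb` plays the role of `self.ram`; the constant 64 is copied from the original formula as-is.

```python
import numpy as np

def batch_size_for(n_samples, n_visible, n_hidden, dim_hidden, ram_gb):
    """Hypothetical helper reproducing the batch-size heuristic in calculate_mis."""
    memory_size = float(n_samples * n_visible * n_hidden * dim_hidden * 64) / 1000**3  # GB
    return int(np.clip(int(ram_gb * n_visible / memory_size), 1, n_visible))

def batches(n_visible, batch_size):
    """Hypothetical helper listing the column slices the loop in calculate_mis visits."""
    return [(i, min(i + batch_size, n_visible)) for i in range(0, n_visible, batch_size)]

# 10000 x 2000 data, 50 factors with 2 states, 8 GB budget:
# estimated size is 128 GB, so 8 / 128 of the 2000 columns (125) go in each batch.
bs = batch_size_for(10000, 2000, 50, 2, 8.0)
print(bs, len(batches(2000, bs)))

# A small problem fits in one batch because the estimate is far below the budget.
print(batch_size_for(100, 20, 3, 2, 8.0), batches(20, 20))
```

Because the last slice is cut off at `n_visible` (the original relies on NumPy slicing doing the same), the batches cover every column exactly once even when `batch_size` does not divide `n_visible`.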