networks.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:comprehend 作者: Fenugreek 项目源码 文件源码
def sample_h_given_v(self, v, eps=1e-5):
    """Draw a stochastic sample of the hidden units given visible input ``v``.

    The hidden-unit means are obtained by evaluating ``self.mean_h`` with
    ``v`` fed into ``self.visible``. Two sampling schemes are supported:

    * ``self.beta_sampling`` false: add Gaussian noise with standard
      deviation ``self.sigma`` to the means and clip the result to
      ``[eps, 1 - eps]``.
    * ``self.beta_sampling`` true: draw from a Beta distribution whose mean
      is the (possibly re-drawn, see below) hidden mean and whose variance is
      ``min(mean * (1 - mean), self.sigma**2)``, up to ``eps`` smoothing.

    Args:
        v: Array fed to ``self.visible``; its dtype is used for the result.
        eps: Small positive constant used for clipping (Gaussian path) and to
            keep the Beta parameters strictly positive.

    Returns:
        Array of samples with the same shape as the evaluated means, cast to
        ``v.dtype``.
    """
    mean_h = self.mean_h.eval(feed_dict={self.visible: v})

    if not self.beta_sampling:
        # randn(*shape) yields the same draws as randn(rows, cols) for 2-D
        # means, but also works for means of any rank.
        rnds = np.random.randn(*mean_h.shape).astype(v.dtype)
        return np.clip(mean_h + rnds * self.sigma, eps, 1. - eps)

    mhhm = mean_h * (1 - mean_h)

    # Handle the cases where h is close to 0.0 or 1.0.
    # Normally the Beta distribution would then give a sample close to 0.0 or
    # 1.0, breaking the requirement that there be some variation (sample
    # dispersion close to 0.0 when it ought to be close to self.sigma).
    small_h = self.sigma**2 > mhhm
    small_count = np.count_nonzero(small_h)
    if small_count:
        # Re-draw the mean of these units uniformly with probability
        # self.sigma.
        switch = np.random.rand(small_count) < self.sigma
        switch_count = np.count_nonzero(switch)
        if switch_count:
            # Build a full-shape mask of the chosen units. The chained form
            # ``mean_h[small_h][switch] = ...`` assigns into the temporary
            # copy produced by boolean indexing and never changes mean_h.
            chosen = np.zeros_like(small_h)
            chosen[small_h] = switch
            # Work on a private copy so the array returned by eval() is
            # neither mutated nor required to be writable.
            mean_h = np.array(mean_h, copy=True)
            mean_h[chosen] = np.random.rand(switch_count)
            mhhm = mean_h * (1 - mean_h)

    var_h = np.fmin(mhhm, self.sigma**2)
    # Method of moments: a Beta(a, b) with mean m and variance s2 has
    # a + b = m * (1 - m) / s2 - 1, a = m * (a + b), b = (1 - m) * (a + b).
    # The eps terms avoid division by zero and keep a, b > 0.
    operand = (mhhm + 1.5 * eps) / (var_h + eps) - 1
    alpha = mean_h * operand + eps
    beta = (1 - mean_h) * operand + eps

    return np.random.beta(alpha, beta).astype(v.dtype)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号