convolution_rbm.py source code

python

Project: SeRanet    Author: corochann
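import numpy as np
from chainer import cuda, Variable


def sample_h_given_v(self, v0_sample):
    """Sample the hidden units given the visible units (one half-step of Gibbs sampling).

    :param v0_sample: Variable holding the visible-unit sample
    :return: tuple (h1_mean, h1_sample)
        h1_mean:   Variable of shape (batch_size, out_channels, image_height_out, image_width_out),
                   the activation probabilities of the hidden units
        h1_sample: Variable of the same shape, the actual sample of the hidden units (0 or 1)
    """
    # Hidden-unit activation probabilities P(h = 1 | v)
    h1_mean = self.propup(v0_sample)
    # Pick the array module (numpy or cupy) that matches where the data lives
    xp = cuda.get_array_module(h1_mean.data)
    if xp == cuda.cupy:
        # GPU: Bernoulli sample by thresholding uniform noise in [0, 1) against the probabilities
        h1_sample = cuda.cupy.random.random_sample(size=h1_mean.data.shape)
        h1_sample[:] = h1_sample < h1_mean.data
    else:  # xp == np
        # CPU: a binomial draw with n=1 is a Bernoulli draw
        h1_sample = np.random.binomial(size=h1_mean.data.shape, n=1, p=h1_mean.data)
    return h1_mean, Variable(h1_sample.astype(xp.float32))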
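Both branches above draw the same thing, an independent Bernoulli sample per hidden unit with P(h = 1) = h1_mean: the CuPy branch compares uniform noise against the probabilities, while the NumPy branch calls a binomial with n=1. The sketch below is a NumPy-only illustration of that equivalence, not code from SeRanet; the helper name `sample_bernoulli` and the example probabilities are hypothetical.

```python
import numpy as np


def sample_bernoulli(p, rng):
    """Hypothetical helper: elementwise 0/1 samples with P(x = 1) = p.

    Uses the same uniform-thresholding trick as the CuPy branch above.
    """
    u = rng.random_sample(p.shape)       # u ~ Uniform[0, 1), same shape as p
    return (u < p).astype(np.float32)    # 1.0 where u < p, else 0.0


rng = np.random.RandomState(0)

# Hypothetical hidden-unit probabilities shaped (batch, out_channels, height, width)
h_mean = np.full((2, 3, 4, 4), 0.25, dtype=np.float32)
h_sample = sample_bernoulli(h_mean, rng)

# Compare empirical means of the thresholding trick with binomial(n=1) draws
p = np.array([0.1, 0.5, 0.9], dtype=np.float32)
draws = sample_bernoulli(np.broadcast_to(p, (100000, 3)), rng)
ref = rng.binomial(n=1, p=p, size=(100000, 3))
print(draws.mean(axis=0))  # each column mean should be close to the matching entry of p
print(ref.mean(axis=0))    # same for the binomial draws
```

The thresholding form is handy when a backend offers only a uniform generator, since it needs nothing beyond `random_sample` and an elementwise comparison.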