SBN.py source code

python

Project: VIMCO  Author: y0ast
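```python
import numpy as np
import theano
import theano.tensor as tt

# Method of the model class in SBN.py. replicate_batch() and
# self.compute_samples() are defined elsewhere in the VIMCO project.
def compile_sampling(self, data_train, data_valid, data_test, training_n_samples):
    X = tt.matrix('X')
    batch = tt.iscalar('batch')
    n_samples = tt.iscalar('n_samples')

    n_layers = len(self.layers)
    samples = [None] * n_layers

    # Replicate the minibatch so every data point gets n_samples sample rows.
    samples[0] = replicate_batch(X, n_samples)

    # On a GPU device use the MRG generator, which has a GPU implementation.
    if "gpu" in theano.config.device:
        from theano.sandbox import rng_mrg
        srng = rng_mrg.MRG_RandomStreams(seed=42)
    else:
        srng = tt.shared_randomstreams.RandomStreams(seed=42)

    # Sample each layer conditioned on the samples of the layer below it.
    for layer in range(n_layers - 1):
        samples[layer + 1] = self.compute_samples(srng, samples[layer], layer)

    # Bind X to one validation minibatch; n_samples remains a function input.
    givens = dict()
    givens[X] = data_valid[batch * self.batch_size:(batch + 1) * self.batch_size]
    self.sample_convergence = theano.function([batch, n_samples], samples, givens=givens)

    # For the remaining functions, fix n_samples to the training value
    # and only swap the dataset that X is sliced from.
    givens[n_samples] = np.int32(training_n_samples)
    givens[X] = data_train[batch * self.batch_size:(batch + 1) * self.batch_size]
    self.sample_train = theano.function([batch], samples, givens=givens)

    givens[X] = data_valid[batch * self.batch_size:(batch + 1) * self.batch_size]
    self.sample_valid = theano.function([batch], samples, givens=givens)

    givens[X] = data_test[batch * self.batch_size:(batch + 1) * self.batch_size]
    self.sample_test = theano.function([batch], samples, givens=givens)
```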
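To make the data flow concrete, here is a NumPy-only sketch of what the compiled `sample_*` functions return for one minibatch. It is not the project's code: `replicate_batch` below is a hypothetical stand-in that repeats each row `n_samples` times, and every layer is assumed to be a sigmoid Bernoulli layer with made-up weights (the real layers sit behind `self.compute_samples`, which is not shown). NumPy's `default_rng(42)` stands in for Theano's random streams, and `minibatch` and `sample_layers` are illustrative names.

```python
import numpy as np


def replicate_batch(x, n_samples):
    """Hypothetical stand-in: repeat each row of x n_samples times in a row.

    Row i of x becomes rows i*n_samples ... (i+1)*n_samples - 1.
    """
    return np.repeat(x, n_samples, axis=0)


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def minibatch(data, batch, batch_size):
    """Same slicing rule the `givens` use to pick a minibatch."""
    return data[batch * batch_size:(batch + 1) * batch_size]


def sample_layers(x, weights, biases, n_samples, rng):
    """Ancestral sampling upward through assumed sigmoid Bernoulli layers.

    weights[l] has shape (units in layer l, units in layer l + 1) and
    biases[l] has shape (units in layer l + 1,); both are illustrative.
    """
    samples = [replicate_batch(x, n_samples)]
    for W, b in zip(weights, biases):
        probs = sigmoid(samples[-1] @ W + b)
        # Each unit is 1 with probability probs, otherwise 0.
        samples.append((rng.random(probs.shape) < probs).astype(x.dtype))
    return samples


rng = np.random.default_rng(42)
data = (rng.random((10, 6)) < 0.5).astype(np.float32)  # 10 binary data points
weights = [rng.normal(size=(6, 4)), rng.normal(size=(4, 2))]
biases = [np.zeros(4), np.zeros(2)]

x = minibatch(data, batch=1, batch_size=4)  # rows 4..7 of data
samples = sample_layers(x, weights, biases, n_samples=5, rng=rng)
print([s.shape for s in samples])  # [(20, 6), (20, 4), (20, 2)]
```

As in the Theano method, the result holds one array per layer (`n_layers` entries, with `n_layers - 1` sampling steps), and each array has `batch_size * n_samples` rows.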