DBN.py source code


Project: DeepLearningTutorialForChinese  Author: zhaoyu611
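import theano
import theano.tensor as T


# Method of the DBN class: self.rbm_layers holds the stacked RBMs and
# self.x is the symbolic input matrix.
def pretraining_functions(self, train_set_x, batch_size, k):
        """
        Build one training function per layer. Each function performs one
        step of gradient descent on its RBM for a single minibatch and takes
        the minibatch index as input, so training an RBM amounts to calling
        its function for every minibatch index.

        train_set_x: theano.tensor.TensorType, shared variable holding all training data
        batch_size: int, size of a minibatch
        k:  int, number of Gibbs steps in CD-k/PCD-k
        """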
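        index = T.lscalar('index')  # minibatch index
        learning_rate = T.scalar('lr')  # learning rate
        # number of minibatches (integer division; not used further in this function)
        n_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size
        # start of the minibatch selected by `index`
        batch_begin = index * batch_size
        # end of the minibatch selected by `index`
        batch_end = batch_begin + batch_size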

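        pretrain_fns = []
        for rbm in self.rbm_layers:  # iterate over every RBM layer
            # get the cost and the parameter updates;
            # train each RBM with CD-k (persistent=None, i.e. not PCD-k)
            cost, updates = rbm.get_cost_updates(learning_rate, persistent=None, k=k)

            # compile the Theano function; learning_rate is an optional input
            # whose default value is 0.1 (theano.In replaces the deprecated theano.Param)
            fn = theano.function(inputs=[index, theano.In(learning_rate, value=0.1)],
                                 outputs=cost, updates=updates,
                                 givens={self.x: train_set_x[batch_begin:batch_end]})
            # append `fn` to the list of pretraining functions
            pretrain_fns.append(fn)
        return pretrain_fns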
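The list returned by pretraining_functions is meant to be consumed one layer at a time (greedy layer-wise pretraining). The pure-Python sketch below mimics that loop without Theano, under stated assumptions: make_fake_pretrain_fn is a hypothetical stand-in for a compiled per-layer function, and its "cost" is just the batch mean. It also reproduces the index-to-slice arithmetic of batch_begin and batch_end and shows that n_batches = N // batch_size never visits a trailing partial batch. Compiled Theano functions can generally be called with keyword arguments named after their symbolic inputs, so the real functions would likely be invoked as fn(index=b, lr=...), which is the convention assumed here.

```python
# Pure-Python sketch (no Theano) of how the list returned by pretraining_functions
# might be consumed. make_fake_pretrain_fn is HYPOTHETICAL: it only mimics the
# calling convention fn(index, lr) and returns the batch mean as a stand-in "cost".

def minibatch_bounds(index, batch_size):
    """Return (begin, end) for minibatch `index`, mirroring batch_begin/batch_end."""
    begin = index * batch_size
    return begin, begin + batch_size


def make_fake_pretrain_fn(data, batch_size, layer_id, log):
    """Hypothetical stand-in for one compiled per-layer training function."""
    def fn(index, lr=0.1):  # lr defaults to 0.1, like theano.In(learning_rate, value=0.1)
        begin, end = minibatch_bounds(index, batch_size)
        batch = data[begin:end]
        log.append((layer_id, index, lr, len(batch)))
        return sum(batch) / len(batch)  # stand-in "cost"
    return fn


def pretrain(pretrain_fns, n_train, batch_size, epochs, lr):
    """Greedy layer-wise loop: all epochs on layer i before moving to layer i + 1."""
    n_batches = n_train // batch_size  # a trailing partial batch is never visited
    history = []
    for layer_id, fn in enumerate(pretrain_fns):
        for epoch in range(epochs):
            costs = [fn(index=b, lr=lr) for b in range(n_batches)]
            history.append((layer_id, epoch, sum(costs) / len(costs)))
    return history


# Usage: 10 examples with batch_size 3 give 3 full minibatches; example 9 is skipped.
data = list(range(10))
log = []
fns = [make_fake_pretrain_fn(data, 3, layer_id, log) for layer_id in range(2)]
history = pretrain(fns, len(data), batch_size=3, epochs=2, lr=0.01)
```

Finishing every epoch on one layer before starting the next matters because each RBM is trained on the representation produced by the layers below it, so those layers should already be trained.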
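Each layer's function applies the updates from rbm.get_cost_updates(learning_rate, persistent=None, k=k), that is, CD-k with the Gibbs chain restarted from the data for every minibatch. The RBM class is not part of this snippet, so the following is only a hedged, dependency-free sketch of a CD-k update for a binary RBM on a single example, assuming the standard sigmoid conditionals. The names cd_k_update, hidden_probs and visible_probs are hypothetical, and this is not the RBM class's actual implementation.

```python
import math
import random


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sample_bernoulli(probs, rng):
    return [1.0 if rng.random() < p else 0.0 for p in probs]


def hidden_probs(v, W, hbias):
    # p(h_j = 1 | v) = sigmoid(sum_i v_i * W[i][j] + hbias[j])
    return [sigmoid(sum(v[i] * W[i][j] for i in range(len(v))) + hbias[j])
            for j in range(len(hbias))]


def visible_probs(h, W, vbias):
    # p(v_i = 1 | h) = sigmoid(sum_j W[i][j] * h_j + vbias[i])
    return [sigmoid(sum(W[i][j] * h[j] for j in range(len(h))) + vbias[i])
            for i in range(len(vbias))]


def cd_k_update(v0, W, vbias, hbias, lr, k, rng):
    """Apply one CD-k update for a single binary example v0 (in place).

    The chain starts from the data on every call (plain CD-k, the
    persistent=None case); PCD-k would instead continue a chain kept
    between calls. Returns the visible sample at the end of the chain.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    ph0 = hidden_probs(v0, W, hbias)       # positive phase statistics
    h = sample_bernoulli(ph0, rng)
    for _ in range(k):                     # k steps of block Gibbs sampling
        vk = sample_bernoulli(visible_probs(h, W, vbias), rng)
        phk = hidden_probs(vk, W, hbias)
        h = sample_bernoulli(phk, rng)
    # Gradient estimate: data statistics minus chain-end statistics.
    for i in range(len(v0)):
        for j in range(len(hbias)):
            W[i][j] += lr * (v0[i] * ph0[j] - vk[i] * phk[j])
        vbias[i] += lr * (v0[i] - vk[i])
    for j in range(len(hbias)):
        hbias[j] += lr * (ph0[j] - phk[j])
    return vk


# Usage: a 4-visible, 3-hidden RBM trained on two binary patterns.
rng = random.Random(0)
n_vis, n_hid = 4, 3
W = [[rng.uniform(-0.1, 0.1) for _ in range(n_hid)] for _ in range(n_vis)]
vbias = [0.0] * n_vis
hbias = [0.0] * n_hid
W_init = [row[:] for row in W]
data = [[1, 1, 0, 0], [0, 0, 1, 1]]
for epoch in range(50):
    for v in data:
        last_sample = cd_k_update(v, W, vbias, hbias, lr=0.1, k=1, rng=rng)
```

Raising k lengthens the Gibbs chain and gives a less biased (but more expensive) gradient estimate; a PCD-k variant would store the final hidden sample and resume from it on the next call instead of restarting from v0.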