lstm_theanompi_outdated.py source code

python

Project: Theano-MPI  Author: uoguelph-mlrg
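def compile_iter_fns(self, *args, **kwargs):
    """Compile the Theano prediction function and the Adadelta training functions."""
    import time

    import theano
    import theano.tensor as tensor
    from theanompi.models.lstm import adadelta

    start = time.time()

    # f_pred_prob = theano.function([x, mask], pred, name='f_pred_prob')
    # Predicted class index for each sample in the batch.
    self.f_pred = theano.function([self.x, self.mask],
                                  self.pred.argmax(axis=1), name='f_pred')

    # f_cost = theano.function([x, mask, y], cost, name='f_cost')
    # Symbolic gradients of the cost with respect to every model parameter.
    grads = tensor.grad(self.cost, wrt=list(self.tparams.values()))
    # f_grad = theano.function([x, mask, y], grads, name='f_grad')

    # Symbolic learning rate passed to the optimizer.
    lr = tensor.scalar(name='lr')

    # adadelta returns two compiled functions, stored here as
    # f_grad_shared and f_update.
    self.f_grad_shared, self.f_update = adadelta(lr, self.tparams, grads,
                                                 self.x, self.mask, self.y,
                                                 self.cost)

    # Only rank 0 reports the compile time.
    if self.rank == 0:
        print('compile time %.3f' % (time.time() - start))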
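
The method above stores two functions returned by `adadelta`: `f_grad_shared` and `f_update`. The listing does not show how they behave, so the following is only a NumPy sketch of the two-phase Adadelta pattern used in the well-known Theano LSTM tutorial, on the assumption that Theano-MPI's `adadelta` follows the same split. The class `AdadeltaSketch`, the callback `quadratic_cost_and_grads`, and the constants `rho` and `eps` are hypothetical names chosen for illustration and are not part of Theano-MPI.

```python
import numpy as np


class AdadeltaSketch:
    """Two-phase Adadelta over a dict of NumPy arrays (illustrative only)."""

    def __init__(self, params, cost_and_grads, rho=0.95, eps=1e-6):
        self.params = params                  # name -> parameter array
        self.cost_and_grads = cost_and_grads  # callable returning (cost, grads dict)
        self.rho = rho
        self.eps = eps
        # Buffers playing the role of Theano shared variables.
        self.grads = {k: np.zeros_like(v) for k, v in params.items()}
        self.running_grads2 = {k: np.zeros_like(v) for k, v in params.items()}
        self.running_up2 = {k: np.zeros_like(v) for k, v in params.items()}

    def f_grad_shared(self, *batch):
        """Phase 1: compute the cost, cache gradients, update E[g^2]."""
        cost, grads = self.cost_and_grads(self.params, *batch)
        for k, g in grads.items():
            self.grads[k] = g
            self.running_grads2[k] = (self.rho * self.running_grads2[k]
                                      + (1.0 - self.rho) * g ** 2)
        return cost

    def f_update(self, lr=None):
        """Phase 2: apply the cached step. lr is accepted but unused,
        since Adadelta derives its own per-parameter step size."""
        for k, g in self.grads.items():
            step = (-np.sqrt(self.running_up2[k] + self.eps)
                    / np.sqrt(self.running_grads2[k] + self.eps) * g)
            self.running_up2[k] = (self.rho * self.running_up2[k]
                                   + (1.0 - self.rho) * step ** 2)
            self.params[k] = self.params[k] + step


def quadratic_cost_and_grads(params):
    """Toy objective sum((w - 3)^2) standing in for the LSTM cost."""
    diff = params["w"] - 3.0
    return float(np.sum(diff ** 2)), {"w": 2.0 * diff}
```

Under these assumptions, a training loop would call `f_grad_shared` on each minibatch to obtain the cost and then `f_update` to move the parameters. Splitting the work this way lets the cost be inspected or logged before the update is committed.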