def compile_iter_fns(self, *args, **kwargs):
    """Compile the Theano functions used by the training/prediction loop.

    Builds and stores on ``self``:

    * ``self.f_pred`` -- a Theano function of ``(self.x, self.mask)`` that
      returns ``self.pred.argmax(axis=1)``.
    * ``self.f_grad_shared`` and ``self.f_update`` -- the two functions
      returned by ``theanompi.models.lstm.adadelta``, built from the
      gradients of ``self.cost`` with respect to every parameter in
      ``self.tparams``.

    The elapsed compilation time is printed only when ``self.rank == 0``.

    Args:
        *args: Ignored; accepted for call-site compatibility.
        **kwargs: Ignored; accepted for call-site compatibility.
    """
    # Imports are deferred to call time and grouped here so that a missing
    # dependency fails before any (slow) graph compilation starts.
    import time

    import theano
    import theano.tensor as tensor

    from theanompi.models.lstm import adadelta

    # perf_counter is monotonic and high-resolution, so the reported duration
    # is not skewed by system clock adjustments the way time.time() can be.
    start = time.perf_counter()

    # Assumes self.pred is (batch, n_classes) so axis=1 is the class axis
    # -- TODO confirm against the model definition.
    self.f_pred = theano.function(
        [self.x, self.mask], self.pred.argmax(axis=1), name='f_pred')

    # Symbolic gradients of the cost, one per entry of self.tparams, in the
    # dict's iteration order.
    grads = tensor.grad(self.cost, wrt=list(self.tparams.values()))

    # Symbolic learning-rate input handed to the optimizer builder.
    lr = tensor.scalar(name='lr')
    # NOTE(review): presumably f_grad_shared evaluates the cost and stores the
    # gradients, while f_update applies the Adadelta step -- verify in adadelta.
    self.f_grad_shared, self.f_update = adadelta(
        lr, self.tparams, grads, self.x, self.mask, self.y, self.cost)

    if self.rank == 0:
        print('compile time %.3f' % (time.perf_counter() - start))
lstm_theanompi_outdated.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录