tagger.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:deep_srl 作者: luheng 项目源码 文件源码
def get_loss_function(self):
    """ We should feed in non-dimshuffled inputs x0, mask0 and y0.
    """
    loss = CrossEntropyLoss().connect(self.scores, self.mask, self.y)
    grads = gradient_clipping(tensor.grad(loss, self.params),
                  self.max_grad_norm)
    updates = adadelta(self.params, grads)

    return theano.function([self.x0, self.mask0, self.y0], loss,
                 name='f_loss',
                 updates=updates,
                 on_unused_input='warn',
                 givens=({self.is_train: numpy.cast['int8'](1)}))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号