sgd.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目: corelm 作者: nusnlp 项目源码 文件源码
def __init__(self, classifier, criterion, learning_rate, trainset, clip_threshold=0):
        """Compile a Theano function that performs one SGD step on a minibatch.

        Args:
            classifier: model exposing ``params`` (the variables updated by the
                compiled step) and ``input`` (the symbolic input variable that
                is substituted with minibatch data).
            criterion: loss object exposing ``cost`` (symbolic cost), the
                target variable ``y`` and, when the training set is weighted,
                the weight variable ``w``.
            learning_rate: stored as ``self.eta``. It is not baked into the
                compiled graph: ``step_func`` receives the rate per call.
            trainset: data source exposing ``is_weighted`` plus ``get_x`` and
                ``get_y`` (and ``get_w`` when weighted). Each accessor is called
                with the symbolic minibatch index, and its result is
                substituted into the graph via ``givens``.
            clip_threshold: when > 0, every gradient is clipped element-wise to
                ``[-clip_threshold, clip_threshold]`` before the update. 0 (the
                default) or a negative value disables clipping.

        After construction, ``self.step_func(index, lr)`` applies
        ``param <- param - lr * gparam`` to every parameter in
        ``classifier.params``. It returns ``[cost] + gparams``, where the
        gradients are the clipped ones when clipping is enabled.
        """
        self.eta = learning_rate
        self.is_weighted = trainset.is_weighted

        params = list(classifier.params)
        # A single T.grad call over all parameters yields one gradient per
        # parameter (same order as params). Wrap in list() so it can be
        # concatenated with [cost] below.
        gparams = list(T.grad(criterion.cost, params))
        if clip_threshold > 0:
            gparams = [T.clip(gparam, -clip_threshold, clip_threshold) for gparam in gparams]

        # Symbolic learning rate supplied on every call, so the caller can
        # change it (e.g. a schedule) without recompiling the function.
        lr = T.fscalar()

        updates = [
            (param, param - lr * gparam)
            for param, gparam in zip(params, gparams)
        ]

        index = T.lscalar()  # index of the [mini]batch to train on

        givens = {
            classifier.input: trainset.get_x(index),
            criterion.y: trainset.get_y(index),
        }
        if self.is_weighted:
            # Weighted training sets also substitute the per-example weights.
            givens[criterion.w] = trainset.get_w(index)

        self.step_func = theano.function(
            inputs=[index, lr],
            outputs=[criterion.cost] + gparams,
            updates=updates,
            givens=givens,
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号