optimizer.py source code

Project: monogreedy  Author: jinjunqi
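```python
import numpy as np
import theano
import theano.tensor as T


def sgd_optimizer(model, lr=0.001, momentum=0.9):
    # Learning rate as a floatX shared scalar so it matches the model's dtype
    lr = theano.shared(np.array(lr).astype(theano.config.floatX))
    # Make sure momentum is a sane value
    assert 0 <= momentum < 1
    # The updates of SGD with momentum; the first cost is the training objective
    updates = []
    grads = T.grad(model.costs[0], model.params)
    for param, grad in zip(model.params, grads):
        # Zero-initialised velocity buffer with the parameter's broadcast pattern,
        # so the update below keeps the same type as `param`
        param_update = theano.shared(param.get_value() * 0.,
                                     broadcastable=param.broadcastable)
        # Theano evaluates every update expression before assigning any of them,
        # so the parameter moves by the velocity stored at the previous call
        updates.append((param, param - lr * param_update))
        updates.append((param_update, momentum * param_update + (1. - momentum) * grad))

    # train_func applies the updates; valid_func only evaluates the costs
    train_func = theano.function(model.inputs, model.costs, updates=updates)
    valid_func = theano.function(model.inputs, model.costs)

    return train_func, valid_func
```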
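Because Theano computes all expressions in `updates` from the values held before the call, the rule this function builds is, per call t with g_t the gradient at θ_t: θ_{t+1} = θ_t - lr·v_t and v_{t+1} = momentum·v_t + (1 - momentum)·g_t. The parameter step therefore lags the gradient by one call, and the first call leaves the parameters unchanged. The pure-Python sketch below mimics that update list without Theano; the function name `lagged_momentum_sgd` and the toy objective f(θ) = θ² (gradient 2θ) are illustrative assumptions, not part of the monogreedy project.

```python
def lagged_momentum_sgd(grad_fn, theta0, lr=0.001, momentum=0.9, steps=100):
    """Hypothetical helper mirroring sgd_optimizer's update list on one scalar.

    Both values are updated from their *old* values (like theano.function
    updates), so theta moves by the velocity from the previous step.
    """
    assert 0 <= momentum < 1
    theta, velocity = theta0, 0.0
    history = [theta]
    for _ in range(steps):
        g = grad_fn(theta)
        # Tuple assignment evaluates the whole right side first,
        # matching Theano's simultaneous application of updates.
        theta, velocity = (theta - lr * velocity,
                           momentum * velocity + (1.0 - momentum) * g)
        history.append(theta)
    return history


if __name__ == "__main__":
    # Toy objective f(theta) = theta**2, whose gradient is 2 * theta
    h = lagged_momentum_sgd(lambda x: 2.0 * x, 1.0, lr=0.1, momentum=0.9, steps=500)
    print([round(x, 4) for x in h[:4]])  # prints [1.0, 1.0, 0.98, 0.942]
```

Two design points follow from the rule. The (1 - momentum) factor makes the velocity an exponential moving average of gradients, so with a constant gradient g it settles at g rather than the g/(1 - momentum) of classical heavy-ball momentum, giving a smaller effective step for the same `lr`. And if the freshly computed velocity were wanted instead, the parameter update would have to use the velocity expression itself (for example `param - lr * (momentum * param_update + (1. - momentum) * grad)`) rather than the shared variable `param_update`; the original keeps the one-step lag.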