import theano
from theano import tensor as T

# Project-local utilities: `op` holds the optimizers (see the docstring
# below); the exact location of `itemlist` (which flattens an OrderedDict of
# parameters into a list) is assumed here.
from utils import op
from utils.tools import itemlist


def set_optimizer(inputs, cost, tparams, constants, updates, extra_outs,
                  optimizer='sgd', optimizer_args=None,
                  **learning_args):
    '''Sets the parameter update functions with optimizer.

    Args:
        inputs (T.tensor): input variables.
        cost (T.scalar): cost.
        tparams (OrderedDict): dictionary of tensor parameters.
        constants (list): list of constant tensors.
        updates (theano.OrderedUpdates): updates.
        extra_outs (list): list of extra output tensors.
        optimizer (Optional[str]): optimizer string. See `utils.op` for
            details. Defaults to `sgd`.
        optimizer_args (Optional[dict]): optional arguments for optimizer.
        **learning_args: extra learning kwargs, passed through unused.

    Returns:
        theano.function: gradient function.
        theano.function: update function.
        dict: extra learning keyword arguments.

    '''
    if optimizer_args is None:
        optimizer_args = dict()

    # Gradients of the cost w.r.t. every parameter; `constants` are treated
    # as fixed during differentiation.
    grads = T.grad(cost, wrt=itemlist(tparams),
                   consider_constant=constants)
    updates = theano.OrderedUpdates(updates)

    # The learning rate is a symbolic input to the update function.
    lr = T.scalar(name='lr')

    # Look up the optimizer by name; `getattr` avoids the unsafe `eval`.
    f_grad_shared, f_grad_updates = getattr(op, optimizer)(
        lr, tparams, grads, inputs, cost, extra_ups=updates,
        extra_outs=extra_outs, **optimizer_args)

    return f_grad_shared, f_grad_updates, learning_args
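
# ----------------------------------------------------------------------------
# Usage sketch (an assumption-laden illustration, not part of the original
# code): it presumes `op.sgd` follows the common split-step contract, where
# `f_grad_shared(*inputs)` computes the cost and caches gradients in shared
# variables, and `f_grad_updates(lr)` applies the parameter update at the
# given learning rate. The toy linear-regression model below is hypothetical
# and only demonstrates the calling convention of `set_optimizer`.
import numpy as np
from collections import OrderedDict

if __name__ == '__main__':
    # Toy model: one weight vector, mean-squared-error cost.
    X = T.matrix('X')
    y = T.vector('y')
    W = theano.shared(np.zeros(3, dtype=theano.config.floatX), name='W')
    tparams = OrderedDict([('W', W)])
    cost = T.mean((T.dot(X, W) - y) ** 2)

    f_grad_shared, f_grad_updates, _ = set_optimizer(
        [X, y], cost, tparams, constants=[],
        updates=theano.OrderedUpdates(), extra_outs=[],
        optimizer='sgd')

    X_np = np.random.randn(10, 3).astype(theano.config.floatX)
    y_np = np.random.randn(10).astype(theano.config.floatX)
    for _ in range(100):
        c = f_grad_shared(X_np, y_np)  # forward pass + gradient computation
        f_grad_updates(0.01)           # SGD step with learning rate 0.01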