def get_updates(self, learning_rate, params, grads, lr_scalers):
    """Compute the parameters' updates (RMSProp).

    Returns a list of (shared_variable, new_expression) pairs suitable
    for the ``updates`` argument of ``theano.function``.
    """
    # Lazily create one accumulator per parameter for the running mean
    # of squared gradients, initialized to zero.
    if self._first_time:
        self.mean_square_grads = [
            sharedX_mtx(
                param.get_value() * 0.,
                name='mean_square_grad_' + param.name,
                borrow=True) for param in params]
        self._first_time = False
    updates = []
    for (param, grad, mean_square_grad, lr_sc) in zip(
            params, grads, self.mean_square_grads, lr_scalers):
        # Exponential moving average of the squared gradient.
        new_mean_square_grad = (
            self.decay * mean_square_grad + (1 - self.decay) * T.sqr(grad))
        # Scale the gradient by the RMS of past gradients, with an
        # epsilon floor to avoid division by (near) zero.
        rms_grad_t = T.sqrt(new_mean_square_grad)
        rms_grad_t = T.maximum(rms_grad_t, self.epsilon)
        lr_scaled = learning_rate * lr_sc
        delta_x_t = -lr_scaled * grad / rms_grad_t
        new_param = param + delta_x_t
        # Optionally constrain the norm of weight matrices.
        if self.max_colm_norm and param.name in ["W", "w"]:
            new_param_final = norm_constraint(tensor_var=new_param,
                                              max_norm=self.max_norm)
        else:
            new_param_final = new_param
        updates.append((param, new_param_final))
        updates.append((mean_square_grad, new_mean_square_grad))
    return updates
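
A minimal sketch of how the returned update list is typically consumed with Theano. The names here are assumptions for illustration: `RMSProp` stands in for whatever class defines `get_updates` above (its constructor arguments are guessed), and the cost, data shapes, and single weight matrix `W` are invented; `sharedX_mtx` and `norm_constraint` are helpers from the surrounding code and are not needed at the call site.

# Hypothetical usage sketch: compile a training step from the updates.
import numpy as np
import theano
import theano.tensor as T

x = T.matrix('x')
y = T.matrix('y')
# One shared parameter; its name 'W' matters for the optional norm constraint.
W = theano.shared(np.zeros((5, 3), dtype=theano.config.floatX),
                  name='W', borrow=True)
cost = T.mean(T.sqr(T.dot(x, W) - y))          # simple squared-error cost
grads = T.grad(cost, [W])

optimizer = RMSProp(decay=0.9, epsilon=1e-6)   # assumed constructor signature
updates = optimizer.get_updates(learning_rate=1e-3,
                                params=[W],
                                grads=grads,
                                lr_scalers=[1.0])
train_fn = theano.function([x, y], cost, updates=updates)

Each call to `train_fn(x_batch, y_batch)` then applies one RMSProp step: the shared accumulators and parameters are updated in place according to the pairs returned by `get_updates`.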