# Module-level imports needed by this excerpt.
from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T


def gradient_descent(self, loss):
    """Momentum gradient descent with global-norm gradient clipping."""
    grads = T.grad(loss, self.params)
    # One shared velocity per parameter, so momentum actually persists
    # across calls to the compiled update function.
    self.momentum_velocity_ = [
        theano.shared(np.zeros_like(p.get_value())) for p in self.params
    ]
    # Global (L2) norm of all gradients, used to detect blow-ups and to clip.
    grad_norm = T.sqrt(sum(T.sqr(g).sum() for g in grads))
    updates = OrderedDict()
    not_finite = T.or_(T.isnan(grad_norm), T.isinf(grad_norm))
    scaling_den = T.maximum(5.0, grad_norm)
    for param, grad, velocity in zip(self.params, grads, self.momentum_velocity_):
        # Rescale so the global gradient norm never exceeds 5.0; if the norm
        # is NaN/Inf, fall back to shrinking the parameter towards zero.
        grad = T.switch(not_finite, 0.1 * param, grad * (5.0 / scaling_den))
        update_step = self.momentum * velocity - self.learning_rate * grad
        updates[velocity] = update_step
        updates[param] = param + update_step
    return updates
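
# Minimal usage sketch (assumption: `model` is an instance of the class this
# method belongs to, and `X`, `y`, `loss` are its symbolic inputs and cost;
# these names are illustrative, not part of the original snippet). The
# returned OrderedDict goes straight into theano.function's `updates` argument.
updates = model.gradient_descent(loss)
train_step = theano.function(
    inputs=[X, y],            # symbolic inputs the loss depends on
    outputs=loss,
    updates=updates,          # applies the clipped, momentum-damped step
)
batch_loss = train_step(X_batch, y_batch)   # one parameter update per call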