careful.py source code

python

Project: crayimage  Author: yandexdataschool
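# Imports are not shown in the original listing; they are assumed from the
# identifiers used below (the two helpers match those in lasagne.updates).
from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T
from lasagne.updates import get_or_compute_grads, total_norm_constraint


def careful_rmsprop(loss_or_grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6, grad_clipping=1.0e-2):
  """
  RMSProp with gradient clipping.
  :param grad_clipping: maximal norm of the gradient; if the norm of the actual gradient exceeds this value, the gradient is rescaled.
  :return: updates
  """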
  grads = get_or_compute_grads(loss_or_grads, params)
  updates = OrderedDict()
  grads = total_norm_constraint(grads, max_norm=grad_clipping, epsilon=epsilon)

  # Using theano constant to prevent upcasting of float32
  one = T.constant(1)

  for param, grad in zip(params, grads):
    value = param.get_value(borrow=True)
    # Running average of squared gradients, one accumulator per parameter, initialised to zero
    accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                         broadcastable=param.broadcastable)
    accu_new = rho * accu + (one - rho) * grad ** 2
    updates[accu] = accu_new
    # Divide the clipped gradient by the root of the running average
    updates[param] = param - (learning_rate * grad /
                              T.sqrt(accu_new + epsilon))

  return updates
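
To make the arithmetic of one update step easier to follow, here is a minimal sketch in plain Python that operates on lists of floats instead of Theano shared variables. The names `clip_total_norm` and `careful_rmsprop_step` are hypothetical and are not part of crayimage. The clipping rule is an assumption about how `total_norm_constraint` behaves: it scales all gradients by `min(norm, max_norm) / (norm + epsilon)`.

```python
import math


def clip_total_norm(grads, max_norm, epsilon=1e-6):
    """Rescale a list of gradient vectors so that their joint L2 norm is at most max_norm.

    Assumed to mirror total_norm_constraint; hypothetical helper for illustration.
    """
    norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    target = min(norm, max_norm)
    scale = target / (norm + epsilon)
    return [[g * scale for g in vec] for vec in grads]


def careful_rmsprop_step(params, grads, accus, learning_rate=1.0, rho=0.9,
                         epsilon=1e-6, grad_clipping=1.0e-2):
    """Apply one clipped RMSProp step; returns (new_params, new_accus)."""
    grads = clip_total_norm(grads, grad_clipping, epsilon)
    new_params, new_accus = [], []
    for p_vec, g_vec, a_vec in zip(params, grads, accus):
        # Running average of squared (clipped) gradients
        a_new = [rho * a + (1.0 - rho) * g * g for a, g in zip(a_vec, g_vec)]
        # Step scaled by the root of the running average
        p_new = [p - learning_rate * g / math.sqrt(a + epsilon)
                 for p, g, a in zip(p_vec, g_vec, a_new)]
        new_params.append(p_new)
        new_accus.append(a_new)
    return new_params, new_accus


# Made-up values: two parameter vectors whose gradients have joint norm 13
params = [[1.0, -2.0], [0.5]]
grads = [[3.0, 4.0], [12.0]]
accus = [[0.0, 0.0], [0.0]]

clipped = clip_total_norm(grads, max_norm=1.0e-2)
clipped_norm = math.sqrt(sum(g * g for vec in clipped for g in vec))
new_params, new_accus = careful_rmsprop_step(params, grads, accus)
```

Because the norm is taken jointly over all parameters, clipping preserves the direction of the overall gradient and only shrinks its length. Gradients whose joint norm is already below `grad_clipping` pass through almost unchanged.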