careful.py source code

python

Project: crayimage  Author: yandexdataschool
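from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T
from lasagne.updates import get_or_compute_grads, total_norm_constraint


def cruel_rmsprop(loss_or_grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6,
                  grad_clipping=1.0e-2, param_clipping=1.0e-2):
  """
  A version of careful RMSProp for Wasserstein GAN.
  :param loss_or_grads: scalar loss expression, or a list of gradient expressions.
  :param params: list of shared variables to update.
  :param learning_rate: step size.
  :param rho: decay rate of the running average of squared gradients.
  :param epsilon: small number for computational stability.
  :param grad_clipping: maximal total norm of the gradients; if the joint norm of the actual gradients exceeds this value, they are rescaled.
  :param param_clipping: after each update all params are clipped to [-`param_clipping`, `param_clipping`]; pass None to disable clipping.
  :return: OrderedDict mapping each shared variable (params and accumulators) to its update expression.
  """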
  grads = get_or_compute_grads(loss_or_grads, params)
  updates = OrderedDict()
  # Rescale all gradients jointly so that their total L2 norm does not exceed grad_clipping.
  grads = total_norm_constraint(grads, max_norm=grad_clipping, epsilon=epsilon)

  # Using theano constant to prevent upcasting of float32
  one = T.constant(1)

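  for param, grad in zip(params, grads):
    # Running average of squared gradients, initialised to zeros.
    value = param.get_value(borrow=True)
    accu = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                         broadcastable=param.broadcastable)
    accu_new = rho * accu + (one - rho) * grad ** 2
    updates[accu] = accu_new

    # RMSProp step: divide the gradient by the root of the running average.
    updated = param - (learning_rate * grad / T.sqrt(accu_new + epsilon))

    # Weight clipping, as in WGAN, keeps every parameter inside [-param_clipping, param_clipping].
    if param_clipping is not None:
      updates[param] = T.clip(updated, -param_clipping, param_clipping)
    else:
      updates[param] = updated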

  return updates
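For readers without a Theano setup, the update that `cruel_rmsprop` builds symbolically can be traced eagerly in NumPy. The sketch below is an illustration under stated assumptions, not the crayimage code. `total_norm_clip` and `cruel_rmsprop_step` are hypothetical names. The rescaling in `total_norm_clip` is assumed to match what Lasagne's `total_norm_constraint` computes: the joint norm is capped at `max_norm` and then divided by `epsilon + norm`. One step clips the total gradient norm, updates the running average of squared gradients, takes the RMSProp step, and finally clips the parameters.

```python
import numpy as np


def total_norm_clip(grads, max_norm, epsilon=1e-7):
    # Hypothetical helper, assumed to mirror Lasagne's total_norm_constraint:
    # scale = min(norm, max_norm) / (epsilon + norm), applied to every array.
    norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(norm, max_norm) / (epsilon + norm)
    return [g * scale for g in grads]


def cruel_rmsprop_step(params, grads, accus, learning_rate=1.0, rho=0.9,
                       epsilon=1e-6, grad_clipping=1.0e-2, param_clipping=1.0e-2):
    # Hypothetical eager counterpart of cruel_rmsprop: performs one update and
    # returns (new_params, new_accus) instead of building Theano expressions.
    grads = total_norm_clip(grads, grad_clipping, epsilon)
    new_params, new_accus = [], []
    for p, g, a in zip(params, grads, accus):
        # Running average of squared gradients.
        a_new = rho * a + (1.0 - rho) * g ** 2
        # RMSProp step using the freshly updated average.
        updated = p - learning_rate * g / np.sqrt(a_new + epsilon)
        # Weight clipping; param_clipping=None disables it.
        if param_clipping is not None:
            updated = np.clip(updated, -param_clipping, param_clipping)
        new_params.append(updated)
        new_accus.append(a_new)
    return new_params, new_accus


params = [np.array([0.005, -0.02]), np.array([[0.0]])]
grads = [np.array([3.0, 4.0]), np.array([[0.0]])]  # joint gradient norm is 5
accus = [np.zeros_like(p) for p in params]
new_params, new_accus = cruel_rmsprop_step(params, grads, accus)
print(new_params[0])  # both entries end up clipped to -0.01
```

In this example, a gradient of norm 5 is first shrunk to a norm of about 0.01. On the first step, however, the RMSProp normalisation makes the step size nearly independent of the gradient's scale: here it is roughly 2.8 to 2.9 per coordinate. The final parameter clip is therefore what actually bounds the weights.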