# Ops modules from TensorFlow 1.x internals that this method relies on.
from tensorflow.python.ops import array_ops, clip_ops, control_flow_ops, math_ops, state_ops

# This method belongs to a custom tf.train.Optimizer subclass (note the use
# of get_slot and self._use_locking from the Optimizer base API).
def _apply_sparse(self, grad, var):
    # Cast hyperparameter tensors to the variable's dtype.
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    clip_multiplier_t = math_ops.cast(self.clip_multiplier_t, var.dtype.base_dtype)
    clip_epsilon_t = math_ops.cast(self.clip_epsilon_t, var.dtype.base_dtype)
    # Gather only the second-moment rows touched by this sparse gradient.
    v = self.get_slot(var, "v")
    v_slice = array_ops.gather(v, grad.indices)
    # Clip the gradient so that each value exceeds its running maximum (v)
    # by no more than a factor of clip_multiplier, plus clip_epsilon slack.
    clipped_values = grad.values
    if self.clip_gradients:
        clip_val = v_slice * clip_multiplier_t + clip_epsilon_t
        clipped_values = clip_ops.clip_by_value(grad.values, -clip_val, clip_val)
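    # Worked example (assumed values, not the class defaults): with
    # clip_multiplier = 1.2 and clip_epsilon = 1e-3, a coordinate whose
    # running max |grad| is 0.5 is clipped to [-0.601, 0.601],
    # i.e. +/-(0.5 * 1.2 + 1e-3).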
    # m := beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_t_values = beta1_t * array_ops.gather(m, grad.indices) + (1 - beta1_t) * clipped_values
    m_t = state_ops.scatter_update(m, grad.indices, m_t_values,
                                   use_locking=self._use_locking)
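    # Note: scatter_update touches only the rows in grad.indices, so the
    # first-moment estimates of rows absent from this step do not decay --
    # the same "lazy" behavior as tf.contrib.opt.LazyAdamOptimizer.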
    # v := max(beta2 * v, abs(grad))
    v_t_values = math_ops.maximum(beta2_t * v_slice, math_ops.abs(clipped_values))
    v_t = state_ops.scatter_update(v, grad.indices, v_t_values,
                                   use_locking=self._use_locking)
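    # This max-based accumulator is the exponentially weighted infinity norm
    # from AdaMax (Kingma & Ba, 2015) rather than Adam's mean of squared
    # gradients, which is also why no sqrt appears in the update below.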
    # variable -= learning_rate * m_t / (epsilon_t + v_t)
    # The bias-correction term for the first moment is omitted; it gives no
    # observable benefit in practice.
    var_update = state_ops.scatter_sub(var, grad.indices,
                                       lr_t * m_t_values / (v_t_values + epsilon_t),
                                       use_locking=self._use_locking)
    return control_flow_ops.group(var_update, v_t, m_t)
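For reference, here is a minimal NumPy sketch of the same per-coordinate update in dense form. It is useful for checking the TensorFlow op sequence above by hand; the function name and the hyperparameter defaults are illustrative assumptions, not values taken from the optimizer class.

import numpy as np

def adamax_clip_step(var, grad, m, v, lr=0.002, beta1=0.9, beta2=0.999,
                     epsilon=1e-8, clip_multiplier=1.2, clip_epsilon=1e-3,
                     clip_gradients=True):
    """One dense optimizer step; returns the updated (var, m, v)."""
    # Optionally clip each gradient entry against its running maximum.
    if clip_gradients:
        clip_val = v * clip_multiplier + clip_epsilon
        grad = np.clip(grad, -clip_val, clip_val)
    m = beta1 * m + (1 - beta1) * grad           # first moment (EMA)
    v = np.maximum(beta2 * v, np.abs(grad))      # infinity-norm second moment
    var = var - lr * m / (v + epsilon)           # no first-moment bias correction
    return var, m, v

# Example: a few steps on the toy loss 0.5 * var**2, whose gradient is var.
var, m, v = np.array([1.0, -2.0]), np.zeros(2), np.zeros(2)
for _ in range(3):
    var, m, v = adamax_clip_step(var, var.copy(), m, v)

Note that with v initialized to zero, the first clipped gradient can be at most clip_epsilon in magnitude, so clip_epsilon is what allows the accumulator to grow away from zero at all.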