def _apply_sparse(self, grad, var):
    """Build the ops that apply a sparse gradient to ``var``.

    The moment slots "m" and "v" are first decayed in full, then the
    gradient contribution is scatter-added at ``grad.indices``. ``var`` is
    then reduced by the Adam-style step and, in a second ``assign_sub``, by
    ``self._sparse_regularization * var``.

    Args:
        grad: Sparse gradient exposing ``.values`` and ``.indices``
            (presumably a ``tf.IndexedSlices`` -- confirm against caller).
        var: The variable being optimized; its "m" and "v" slots are
            looked up via ``self.get_slot``.

    Returns:
        The op produced by ``control_flow_ops.group`` over the final
        variable update and the two moment updates.
    """
    locking = self._use_locking

    # Step size with both bias corrections folded in:
    #   lr_t * sqrt(1 - beta2_power) / (1 - beta1_power)
    bias_correction2 = math_ops.sqrt(1 - self._beta2_power)
    scaled_lr = self._lr_t * bias_correction2
    step_size = scaled_lr / (1 - self._beta1_power)

    # First moment: m_t = beta1 * m + (1 - beta1) * g_t
    # The decay touches the whole slot; only the gradient term is sparse.
    m_slot = self.get_slot(var, "m")
    m_increment = grad.values * (1 - self._beta1_t)
    m_decayed = state_ops.assign(
        m_slot, m_slot * self._beta1_t, use_locking=locking
    )
    m_new = state_ops.scatter_add(
        m_decayed, grad.indices, m_increment, use_locking=locking
    )

    # Second moment: v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v_slot = self.get_slot(var, "v")
    squared_grad = grad.values * grad.values
    v_increment = squared_grad * (1 - self._beta2_t)
    v_decayed = state_ops.assign(
        v_slot, v_slot * self._beta2_t, use_locking=locking
    )
    v_new = state_ops.scatter_add(
        v_decayed, grad.indices, v_increment, use_locking=locking
    )

    # The exponent is configurable through self._pow_t
    # (presumably 0.5, i.e. a square root, for plain Adam -- TODO confirm).
    # NOTE(review): this is the only call in the block that goes through
    # `tf` instead of `math_ops`.
    v_root = tf.pow(v_new, self._pow_t)

    # Adam-style step: var -= step_size * m_t / (v_t ** pow + epsilon)
    numerator = step_size * m_new
    adam_step = numerator / (v_root + self._epsilon_t)
    stepped_var = state_ops.assign_sub(var, adam_step, use_locking=locking)

    # Regularization: var -= sparse_regularization * var
    # NOTE(review): the penalty is computed from `var` itself, not from
    # `stepped_var`, so nothing here pins whether it sees the value from
    # before or after the step above -- confirm this is intended.
    penalty = self._sparse_regularization * var
    regularized_var = state_ops.assign_sub(
        stepped_var, penalty, use_locking=locking
    )

    return control_flow_ops.group(regularized_var, m_new, v_new)
评论列表
文章目录