def _apply_sparse(self, grad, var):
    beta1_power = tf.cast(self._beta1_power, var.dtype.base_dtype)
    beta2_power = tf.cast(self._beta2_power, var.dtype.base_dtype)
    lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = tf.cast(self._epsilon_t, var.dtype.base_dtype)
    # Bias-corrected step size: lr_t * sqrt(1 - beta2^t) / (1 - beta1^t)
    lr = lr_t * tf.sqrt(1 - beta2_power) / (1 - beta1_power)
    # m := beta1 * m + (1 - beta1) * g_t, updated only at grad.indices
    m = self.get_slot(var, "m")
    m_t = tf.scatter_update(m, grad.indices,
                            beta1_t * tf.gather(m, grad.indices) +
                            (1 - beta1_t) * grad.values,
                            use_locking=self._use_locking)
    # v := beta2 * v + (1 - beta2) * (g_t * g_t), updated only at grad.indices
    v = self.get_slot(var, "v")
    v_t = tf.scatter_update(v, grad.indices,
                            beta2_t * tf.gather(v, grad.indices) +
                            (1 - beta2_t) * tf.square(grad.values),
                            use_locking=self._use_locking)
    # variable -= lr * m_t / (sqrt(v_t) + epsilon_t), again only at grad.indices
    m_t_slice = tf.gather(m_t, grad.indices)
    v_t_slice = tf.gather(v_t, grad.indices)
    denominator_slice = tf.sqrt(v_t_slice) + epsilon_t
    var_update = tf.scatter_sub(var, grad.indices,
                                lr * m_t_slice / denominator_slice,
                                use_locking=self._use_locking)
    return tf.group(var_update, m_t, v_t)
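
Gradients of embedding-style variables arrive as tf.IndexedSlices, so the optimizer dispatches them to _apply_sparse rather than _apply_dense. Below is a minimal usage sketch, assuming TensorFlow 1.x graph mode and that the method above is defined at module level; the class name LazyAdamOptimizer is chosen here for illustration (it mirrors tf.contrib.opt.LazyAdamOptimizer, which implements the same idea):

import tensorflow as tf  # TensorFlow 1.x; in a real file this import
                         # must precede the _apply_sparse definition above

class LazyAdamOptimizer(tf.train.AdamOptimizer):
    # Override only the sparse path with the lazy update defined above;
    # the dense path (_apply_dense) is inherited from AdamOptimizer.
    _apply_sparse = _apply_sparse

# Embedding lookups produce IndexedSlices gradients, so only the rows
# actually gathered this step have their m/v slots and weights touched.
emb = tf.get_variable("emb", shape=[10000, 64])
ids = tf.placeholder(tf.int32, shape=[None])
loss = tf.reduce_sum(tf.nn.embedding_lookup(emb, ids))
train_op = LazyAdamOptimizer(learning_rate=1e-3).minimize(loss)

The design trade-off: stock Adam decays every row of m and v on each step even when the gradient is sparse, while this lazy variant updates only the rows in grad.indices. That slightly changes the update semantics but is far cheaper for large embedding tables.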