def _apply_dense(self, grad, weight):
    """Build the update ops for one dense variable.

    Using the ``"n"`` (norm) and ``"a"`` (momentum) slots of ``weight``, each
    step computes::

        n      <- learning_rate / n + n
        a      <- grad / n + mu * mean(a)
        weight <- weight - learning_rate * a

    where ``mean(a)`` is the per-row mean over axis 1 when the momentum slot
    is rank 2, and the momentum slot itself for any other rank. The momentum
    update uses the freshly assigned ``n``, and the weight update uses the
    freshly assigned ``a``.

    Args:
      grad: Dense gradient for ``weight``.
      weight: The variable being optimized.

    Returns:
      A ``tf.group`` op running the weight, norm and momentum assignments.
    """
    dtype = weight.dtype.base_dtype
    learning_rate_t = tf.cast(self._lr_t, dtype)
    mu_t = tf.cast(self._mu_t, dtype)
    momentum = self.get_slot(weight, "a")
    norm = self.get_slot(weight, "n")

    if momentum.get_shape().ndims == 2:
        # keep_dims=True leaves a (rows, 1) mean that broadcasts back against
        # the (rows, cols) gradient term below.
        momentum_mean = tf.reduce_mean(momentum, axis=1, keep_dims=True)
    else:
        momentum_mean = momentum

    # The running norm is read from the "n" slot on every call, so the new
    # value must be written back into that slot variable (tf.assign needs a
    # variable; a cast hyperparameter tensor cannot be assigned to).
    # NOTE(review): assumes the "n" slot is initialised to a non-zero value;
    # a zero slot would make learning_rate / norm divide by zero — confirm
    # where the slots are created.
    norm_t = tf.assign(norm, learning_rate_t / norm + norm,
                       use_locking=self._use_locking)

    momentum_update = (grad / norm_t) + (mu_t * momentum_mean)
    momentum_t = tf.assign(momentum, momentum_update,
                           use_locking=self._use_locking)

    weight_update = learning_rate_t * momentum_t
    weight_t = tf.assign_sub(
        weight, weight_update, use_locking=self._use_locking)
    return tf.group(*[weight_t, norm_t, momentum_t])
评论列表
文章目录