optimizer.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def _apply_sparse(self, grad, var):
        """Build the update ops for ``var`` from a sparse gradient.

        Only the entries selected by ``grad.indices`` are read and written,
        both in the variable and in its "m" / "v" slots; the rest are left
        untouched by this method.

        Args:
            grad: Sparse gradient exposing ``indices`` and ``values``
                (presumably a ``tf.IndexedSlices`` -- confirm against caller).
            var: Variable to update; its ``dtype.base_dtype`` is the dtype
                every hyper-parameter is cast to.

        Returns:
            A grouped op covering the variable update and both slot updates.

        NOTE(review): if ``grad.indices`` holds duplicate entries, the
        ``scatter_update`` calls below write the same row more than once.
        Presumably the caller de-duplicates indices first -- confirm.
        """
        dtype = var.dtype.base_dtype
        indices = grad.indices
        locking = self._use_locking

        # Bring every hyper-parameter to the variable's dtype in one pass.
        (beta1_power, beta2_power, lr_t,
         beta1_t, beta2_t, epsilon_t) = (
            tf.cast(hyper, dtype)
            for hyper in (self._beta1_power, self._beta2_power, self._lr_t,
                          self._beta1_t, self._beta2_t, self._epsilon_t))

        # Step size with the bias-correction factors folded in:
        # lr_t * sqrt(1 - beta2_power) / (1 - beta1_power)
        second_moment_fix = tf.sqrt(1 - beta2_power)
        lr = lr_t * second_moment_fix / (1 - beta1_power)

        # First moment, touched rows only: m := beta1 * m + (1 - beta1) * g_t
        m = self.get_slot(var, "m")
        m_rows_old = tf.gather(m, indices)
        m_rows_new = beta1_t * m_rows_old + (1 - beta1_t) * grad.values
        m_t = tf.scatter_update(m, indices, m_rows_new, use_locking=locking)

        # Second moment, touched rows only:
        # v := beta2 * v + (1 - beta2) * (g_t * g_t)
        v = self.get_slot(var, "v")
        v_rows_old = tf.gather(v, indices)
        v_rows_new = (beta2_t * v_rows_old +
                      (1 - beta2_t) * tf.square(grad.values))
        v_t = tf.scatter_update(v, indices, v_rows_new, use_locking=locking)

        # Re-gather from the scatter results so the step is computed from the
        # freshly written moments: variable -= lr * m_t / (sqrt(v_t) + epsilon)
        m_rows = tf.gather(m_t, indices)
        v_rows = tf.gather(v_t, indices)
        denom = tf.sqrt(v_rows) + epsilon_t
        step = lr * m_rows / denom
        var_update = tf.scatter_sub(var, indices, step, use_locking=locking)

        return tf.group(var_update, m_t, v_t)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号