radam_optimizer.py source code

python

Project: Sing_Par  Author: wanghm92
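import tensorflow as tf  # TF 1.x API (tf.unique, tf.unsorted_segment_sum)

def _apply_sparse(self, cache):
    """Compute the update step for a sparse (row-indexed) gradient.

    Duplicate row indices are merged first; the first and second moment
    estimates are then updated only for the rows that were touched. The
    step is stored in cache['s_t'], together with the deduplicated
    gradient and indices, and the moving-average update ops are appended
    to cache['updates'].
    """

    x_tm1, g_t, idxs = cache['x_tm1'], cache['g_t'], cache['idxs']
    # Deduplicate the row indices and sum the gradient rows that share an index.
    idxs, idxs_ = tf.unique(idxs)
    g_t_ = tf.unsorted_segment_sum(g_t, idxs_, tf.size(idxs))
    updates = cache['updates']

    if self.mu > 0:
      # First moment: moving average of the gradient on the touched rows,
      # then interpolated with the raw gradient by gamma.
      m_t, t_m = self._sparse_moving_average(x_tm1, idxs, g_t_, 'm', beta=self.mu)
      m_t_ = tf.gather(m_t, idxs)
      m_bar_t_ = (1-self.gamma) * m_t_ + self.gamma * g_t_
      updates.extend([m_t, t_m])
    else:
      m_bar_t_ = g_t_

    if self.nu > 0:
      # Second moment: moving average of the squared gradient; epsilon is
      # added inside the square root.
      v_t, t_v = self._sparse_moving_average(x_tm1, idxs, g_t_**2, 'v', beta=self.nu)
      v_t_ = tf.gather(v_t, idxs)
      v_bar_t_ = tf.sqrt(v_t_ + self.epsilon)
      updates.extend([v_t, t_v])
    else:
      v_bar_t_ = 1

    # Per-row step size; it is stored in the cache rather than applied here.
    s_t_ = self.learning_rate * m_bar_t_ / v_bar_t_
    cache['s_t'] = s_t_
    cache['g_t'] = g_t_
    cache['idxs'] = idxs
    return cache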
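The same sparse update can be traced outside TensorFlow. The NumPy sketch below is a minimal illustration under stated assumptions, not the project's implementation. The body of `_sparse_moving_average` is not shown in the source, so the sketch assumes it keeps a per-row step counter and folds Adam-style bias correction into the decay, beta_t = beta * (1 - beta^(t-1)) / (1 - beta^t). It also assumes the caller subtracts `s_t` from the selected rows. The helper names `dedup_sparse_grad`, `sparse_moving_average` and `sparse_step`, the `state` dictionary layout, and the default hyperparameter values are all hypothetical. `np.unique` returns sorted indices, whereas `tf.unique` keeps first-occurrence order; the grouping of rows is the same.

```python
import numpy as np


def dedup_sparse_grad(g_t, idxs):
    """Merge gradient rows sharing an index (like tf.unique + unsorted_segment_sum)."""
    uniq, inverse = np.unique(idxs, return_inverse=True)
    g_sum = np.zeros((uniq.size,) + g_t.shape[1:], dtype=g_t.dtype)
    np.add.at(g_sum, inverse, g_t)  # unbuffered add, so repeated indices accumulate
    return uniq, g_sum


def sparse_moving_average(acc, steps, idxs, a_t, beta):
    """Hypothetical stand-in for _sparse_moving_average; updates only rows in idxs.

    Assumes 0 < beta < 1 and that idxs contains no duplicates.
    """
    steps[idxs] += 1  # per-row step counter
    t = steps[idxs].reshape((-1,) + (1,) * (acc.ndim - 1))
    # Assumed decay with bias correction folded in; equals 0 on a row's first update.
    beta_t = beta * (1 - beta ** (t - 1)) / (1 - beta ** t)
    acc[idxs] = beta_t * acc[idxs] + (1 - beta_t) * a_t
    return acc


def sparse_step(x, state, g_t, idxs, lr=0.002, mu=0.9, nu=0.9, gamma=0.0, eps=1e-12):
    """One sparse update of x in place; hyperparameter defaults are illustrative only."""
    idxs, g_ = dedup_sparse_grad(g_t, idxs)
    if mu > 0:
        m = sparse_moving_average(state['m'], state['t_m'], idxs, g_, mu)
        m_bar = (1 - gamma) * m[idxs] + gamma * g_
    else:
        m_bar = g_
    if nu > 0:
        v = sparse_moving_average(state['v'], state['t_v'], idxs, g_ ** 2, nu)
        v_bar = np.sqrt(v[idxs] + eps)
    else:
        v_bar = 1.0
    s = lr * m_bar / v_bar
    x[idxs] -= s  # assumed sign convention: descend along the step
    return x
```

On a row's first update the assumed decay is zero, so the averages equal the raw gradient and its square, and each touched entry moves by roughly `lr`; untouched rows and their step counters are left unchanged, which is the point of the sparse path.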