base_optimizer.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:Parser-v1 作者: tdozat 项目源码 文件源码
def _sparse_moving_average(self, x_tm1, idxs, a_t_, name, beta=.9):
    """"""

    b_tm1 = self.get_accumulator(x_tm1, '%s' % name)
    b_tm1_ = tf.gather(b_tm1, idxs)
    shape = self.get_variable_shape(x_tm1)
    tm1 = self.get_accumulator(x_tm1, '%s/tm1' % name, shape=[shape[0]]+[1]*(len(shape)-1))
    tm1_ = tf.gather(tm1, idxs)
    t = tf.scatter_add(tm1, idxs, tf.ones_like(tm1_))
    t_ = tf.gather(t, idxs)
    if beta < 1:
      beta_t = tf.convert_to_tensor(beta, name='%s/decay' % name)
      beta_t_ = beta_t * (1-beta_t**tm1_) / (1-beta_t**t_)
    else:
      beta_t_ = tm1_/t_
    b_t = tf.scatter_update(b_tm1, idxs, beta_t_*b_tm1_)
    b_t = tf.scatter_add(b_t, idxs, (1-beta_t_)*a_t_)
    return b_t, t

  #=============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号