base_optimizer.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:UnstableParser 作者: tdozat 项目源码 文件源码
def _finish(self, caches):
    """Build the grouped op that applies every cached step to its variable.

    Each entry of `caches` is a dict read for the keys 'x_tm1', 's_t' and
    'updates', plus an optional 'idxs' key. The names suggest 'x_tm1' is the
    variable being optimized and 's_t' the step to subtract from it
    (presumably; confirm against the code that fills the caches).

    Steps:
      1. If `self.clip > 0`, all 's_t' values are clipped together with
         `tf.clip_by_global_norm` and written back into their caches.
      2. For each cache, 's_t' is subtracted from 'x_tm1': via
         `tf.scatter_sub` at 'idxs' when that key is present, otherwise via
         `tf.assign_sub`.
      3. If `self.chi > 0`, the matching moving-average helper is called and
         its two results are queued alongside the variable update.

    Args:
      caches: iterable of per-variable dicts as described above. The dicts
        are modified in place ('s_t' may be replaced, 'updates' is appended
        to).

    Returns:
      A `tf.group` op named 'update' wrapping one `tf.group` per cache.
    """

    if self.clip > 0:
      raw_steps = [entry['s_t'] for entry in caches]
      clipped_steps, _ = tf.clip_by_global_norm(raw_steps, self.clip)
      for entry, clipped in zip(caches, clipped_steps):
        entry['s_t'] = clipped

    for entry in caches:
      variable = entry['x_tm1']
      step = entry['s_t']
      pending = entry['updates']
      # Stays empty unless averaging is enabled (self.chi > 0).
      average_ops = []
      with tf.name_scope('update_' + variable.op.name), tf.device(variable.device):
        if 'idxs' not in entry:
          # Dense case: subtract the whole step from the variable.
          new_value = tf.assign_sub(variable, step)
          if self.chi > 0:
            average, average_count = self._dense_moving_average(variable, new_value, 'x', beta=self.chi)
            average_ops = [average, average_count]
        else:
          # Sparse case: only the rows selected by 'idxs' are updated.
          rows = entry['idxs']
          new_value = tf.scatter_sub(variable, rows, step)
          if self.chi > 0:
            # Read back the freshly updated rows for the moving average.
            updated_rows = tf.gather(new_value, rows)
            average, average_count = self._sparse_moving_average(variable, rows, updated_rows, 'x', beta=self.chi)
            average_ops = [average, average_count]
      pending.append(new_value)
      pending.extend(average_ops)

    grouped = []
    for entry in caches:
      grouped.append(tf.group(*entry['updates']))
    return tf.group(*grouped, name='update')

  #==============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号