moving_average_optimizer.py source code

python

Project: lsdc  Author: febert
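# Imports this method relies on (assumed TensorFlow 1.x internal module paths).
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops


def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    # Delegate the actual parameter update to the wrapped optimizer.
    train_op = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
    # Track only the variables that actually received a gradient.
    var_list = [x[1] for x in grads_and_vars if x[0] is not None]
    self._variable_map = {}
    if self._sequential_update:
        # Update the moving averages only after the training step has run.
        with ops.control_dependencies([train_op]):
            ma_op = self._ema.apply(var_list)
    else:
        ma_op = self._ema.apply(var_list)

    # Build a bidirectional name map between each variable and its average.
    for v in var_list:
        v_avg = self._ema.average(v)
        self._variable_map[v.op.name] = v_avg
        self._variable_map[v_avg.op.name] = v
    # Return a single op that runs both the training step and the average update.
    return control_flow_ops.group(train_op, ma_op, name="train_with_avg")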
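
The `self._ema.apply(var_list)` and `self._ema.average(v)` calls above follow an apply/average pattern that can be sketched in plain Python. This is a minimal illustration, assuming the usual exponential-moving-average rule shadow = decay * shadow + (1 - decay) * value, with the shadow initialised to the first value seen. The class `SimpleEMA` and the dictionary-based interface are hypothetical stand-ins, not TensorFlow APIs.

```python
class SimpleEMA:
    """Hypothetical stand-in for an EMA tracker; not a TensorFlow class."""

    def __init__(self, decay):
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be in [0, 1]")
        self.decay = decay
        self._shadow = {}  # name -> averaged ("shadow") value

    def apply(self, values):
        """Update the shadow copy of every named value."""
        for name, value in values.items():
            if name not in self._shadow:
                # First sighting: initialise the shadow to the current value.
                self._shadow[name] = float(value)
            else:
                prev = self._shadow[name]
                self._shadow[name] = self.decay * prev + (1.0 - self.decay) * value

    def average(self, name):
        """Return the current averaged value for a name."""
        return self._shadow[name]


ema = SimpleEMA(decay=0.5)
ema.apply({"w": 1.0})  # shadow initialised to 1.0
ema.apply({"w": 3.0})  # 0.5 * 1.0 + 0.5 * 3.0 = 2.0
ema.apply({"w": 4.0})  # 0.5 * 2.0 + 0.5 * 4.0 = 3.0
```

In the method above, `_sequential_update` decides whether this averaging step is forced to run after the training step or is simply grouped with it; the sketch has no such ordering concern because it runs eagerly.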