def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients with the wrapped optimizer and update moving averages.

    Delegates the parameter update to ``self._optimizer.apply_gradients``.
    It then builds a moving-average update op with ``self._ema.apply`` for
    every variable whose gradient is not None. Finally it rebuilds
    ``self._variable_map`` as a two-way mapping between each variable's op
    name and its averaged counterpart.

    Args:
      grads_and_vars: Iterable of ``(gradient, variable)`` pairs. Any
        iterable works, including a one-shot generator: it is turned into a
        list before use, so it is only consumed once.
      global_step: Forwarded unchanged to the wrapped optimizer's
        ``apply_gradients``.
      name: Forwarded unchanged to the wrapped optimizer's
        ``apply_gradients``.

    Returns:
      A grouped op named ``"train_with_avg"`` that combines the training op
      and the moving-average op.
    """
    # Materialize once. The wrapped optimizer iterates grads_and_vars, and so
    # does the var_list comprehension below. A generator would be used up by
    # the first pass, leaving var_list empty and silently skipping every
    # moving-average update.
    grads_and_vars = list(grads_and_vars)
    train_op = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
    # Only variables that actually received a gradient get an average.
    var_list = [pair[1] for pair in grads_and_vars if pair[0] is not None]
    self._variable_map = {}
    if self._sequential_update:
        # Force the moving-average update to be ordered after the training
        # op, so the averages see the freshly updated variable values.
        with ops.control_dependencies([train_op]):
            ma_op = self._ema.apply(var_list)
    else:
        # No explicit ordering with respect to train_op.
        ma_op = self._ema.apply(var_list)
    # Record the mapping in both directions, keyed by op name:
    # variable -> average and average -> variable.
    for var in var_list:
        var_avg = self._ema.average(var)
        self._variable_map[var.op.name] = var_avg
        self._variable_map[var_avg.op.name] = var
    return control_flow_ops.group(train_op, ma_op, name="train_with_avg")
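# Comment list (page-navigation residue from the scraped source; not code)
# Table of contents (page-navigation residue from the scraped source; not code)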