def build_all(self, param_avg=False):
    """Build the input, model output, loss and optimizer nodes.

    The nodes are created on the device returned by ``self.get_device_fn()``
    and inside a variable scope named ``self.name``. The resulting training
    op is registered under the key ``'train_step'`` via
    ``self.register_var``.

    Args:
        param_avg: If True, also obtain an averaging op and averaged
            variables from ``self.get_average_var()``. The averaged variables
            are stored on ``self._avg_var``. The registered ``train_step``
            becomes a no-op that depends on both the optimizer step and the
            averaging op.

    Returns:
        ``self``, so the call can be chained.

    Raises:
        RuntimeError: If ``self._has_built_all`` is already set, i.e. the
            graph has already been built on this instance.
    """
    if self._has_built_all:
        # RuntimeError subclasses Exception, so existing
        # ``except Exception`` handlers still catch this.
        raise RuntimeError('Only call build_all or build_eval once.')
    # The flag is set before any building happens, so a second call is
    # rejected even if this one fails part-way through.
    self._has_built_all = True
    with tf.device(self.get_device_fn()):
        with tf.variable_scope(self.name):
            inp_var = self.build_input()
            output_var = self.build(inp_var)
            loss_var = self.build_loss(inp_var, output_var)
            train_step = self.build_optim(loss_var)
            if param_avg:
                ema_op, avg_var = self.get_average_var()
                self._avg_var = avg_var
                # Put the optimizer step and the averaging op behind one op,
                # so a single run of ``train_step`` executes both.
                # NOTE(review): listing both in one control_dependencies
                # call imposes no order between them. If the averages should
                # see post-update parameter values, ema_op must itself depend
                # on the optimizer step -- confirm in get_average_var().
                with tf.control_dependencies([train_step, ema_op]):
                    train_step = tf.no_op(name='train_step')
            self.register_var('train_step', train_step)
    return self
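# NOTE: stray web-page navigation text left over from extraction:
# 评论列表 ("comment list")
# 文章目录 ("table of contents")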