def build_optim(self, loss):
    # Note: `loss` is not used directly here; the per-tower gradients
    # have already been computed and stored in `self.tower_grads`.
    global_step = self.global_step
    learn_rate = self.learn_rate
    # Average the gradient of each variable across all towers. This is
    # the synchronization point between the towers.
    grads = self.average_gradients(self.tower_grads)
    # Apply the averaged gradients to adjust the shared variables.
    apply_gradient_op = self.opt.apply_gradients(
        grads, global_step=global_step)
    # Track the moving averages of all trainable variables, typically
    # restored in place of the raw weights at evaluation time.
    variable_averages = tf.train.ExponentialMovingAverage(
        0.999, global_step)
    variables_averages_op = variable_averages.apply(
        tf.trainable_variables())
    # Group all updates into a single train op.
    train_op = tf.group(apply_gradient_op, variables_averages_op)
    return train_op
Source: resnet_imagenet_model_multi_wrapper.py (Python)
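The comment in build_optim calls self.average_gradients the synchronization point across towers, but that helper's body is not part of this excerpt. Below is a minimal sketch of what such an averaging step usually looks like, following the pattern of the TensorFlow multi-GPU CIFAR-10 tutorial; the standalone function name average_gradients and the (gradient, variable) tuple layout from opt.compute_gradients() are assumptions about the surrounding class, not code from this file.

import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) tuples per tower,
    # as produced by opt.compute_gradients() on each tower.
    average_grads = []
    # zip(*tower_grads) regroups the tuples so that each iteration sees
    # the gradients of a single variable across all towers.
    for grad_and_vars in zip(*tower_grads):
        # Stack the per-tower gradients along a new leading axis, then
        # average over that axis.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)
        # The variables are shared across towers, so the variable from
        # the first tower stands in for all of them.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads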
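Because the returned train op also maintains exponential moving averages, an evaluation script would normally restore those shadow values in place of the raw weights. A minimal sketch of that eval-side restore, assuming the same 0.999 decay and a hypothetical checkpoint directory:

import tensorflow as tf

# Rebuild the EMA object with the decay used during training, then map
# each model variable to its shadow (moving-average) value for restore.
variable_averages = tf.train.ExponentialMovingAverage(0.999)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint('/tmp/train_dir')  # hypothetical path
    saver.restore(sess, ckpt)
    # The graph's variables now hold their moving-average values.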