# Excerpted from mnist_multi_gpu_train.py; `model` and `FLAGS` are defined
# elsewhere in that file.
import re

import tensorflow as tf


def tower_loss(scope):
"""Calculate the total loss on a single tower running the MNIST model.
Args:
scope: unique prefix string identifying the MNIST tower, e.g. 'tower_0'
Returns:
Tensor of shape [] containing the total loss for a batch of data
"""
# Get images and labels for MSNIT.
images, labels = model.inputs(FLAGS.batch_size)
# Build inference Graph.
logits = model.inference(images, keep_prob=0.5)
# Build the portion of the Graph calculating the losses. Note that we will
# assemble the total_loss using a custom function below.
_ = model.loss(logits, labels)
# Assemble all of the losses for the current tower only.
losses = tf.get_collection('losses', scope)
# Calculate the total loss for the current tower.
total_loss = tf.add_n(losses, name='total_loss')
# Attach a scalar summary to all individual losses and the total loss; do
# the same for the averaged version of the losses.
if (FLAGS.tb_logging):
for l in losses + [total_loss]:
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU
# training session. This helps the clarity of presentation on
# tensorboard.
loss_name = re.sub('%s_[0-9]*/' % model.TOWER_NAME, '', l.op.name)
tf.summary.scalar(loss_name, l)
return total_loss
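
For context, below is a minimal sketch of how tower_loss is typically driven in the TensorFlow 1.x multi-GPU pattern: one call per GPU, each under its own device and name scope, so that tf.get_collection('losses', scope) only collects that tower's losses. The names FLAGS.num_gpus, FLAGS.learning_rate, and model.TOWER_NAME are assumed to be defined elsewhere in mnist_multi_gpu_train.py, and the optimizer choice is illustrative only.

# Sketch only: assumes FLAGS.num_gpus, FLAGS.learning_rate, and
# model.TOWER_NAME are defined elsewhere in mnist_multi_gpu_train.py.
tower_grads = []
opt = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
with tf.variable_scope(tf.get_variable_scope()):
    for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('%s_%d' % (model.TOWER_NAME, i)) as scope:
                # Each tower builds its own copy of the graph; the scope
                # string restricts tf.get_collection('losses', scope) to
                # the losses created on this tower.
                loss = tower_loss(scope)
                # Share variables across towers so all GPUs train one model.
                tf.get_variable_scope().reuse_variables()
                # Keep this tower's gradients; they are averaged over all
                # towers before the optimizer applies them.
                tower_grads.append(opt.compute_gradients(loss))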