def optimize_clones(clones, optimizer,
                    regularization_losses=None,
                    **kwargs):
  """Compute clone losses and gradients for the given list of `Clones`.

  Note: The regularization_losses are added to the losses of the first clone.

  Args:
    clones: List of `Clones` created by `create_clones()`.
    optimizer: An `Optimizer` object.
    regularization_losses: Optional list of regularization losses. If None, it
      will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to
      exclude them.
    **kwargs: Optional keyword arguments to pass to `compute_gradients`.

  Returns:
    A tuple (total_loss, grads_and_vars).
      - total_loss: A Tensor containing the average of the clone losses,
        including the regularization loss.
      - grads_and_vars: A list of (gradient, variable) tuples containing the
        sum of the gradients for each variable.
  """
  grads_and_vars = []
  clones_losses = []
  num_clones = len(clones)
  if regularization_losses is None:
    regularization_losses = tf.get_collection(
        tf.GraphKeys.REGULARIZATION_LOSSES)
  for clone in clones:
    with tf.name_scope(clone.scope):
      clone_loss, clone_grad = _optimize_clone(
          optimizer, clone, num_clones, regularization_losses, **kwargs)
      if clone_loss is not None:
        clones_losses.append(clone_loss)
        grads_and_vars.append(clone_grad)
      # Only use regularization_losses for the first clone.
      regularization_losses = None
  # Compute the total_loss by summing all the clones_losses.
  total_loss = tf.add_n(clones_losses, name='total_loss')
  # Sum the gradients across clones.
  grads_and_vars = _sum_clones_gradients(grads_and_vars)
  return total_loss, grads_and_vars
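
The private helper `_sum_clones_gradients` is not shown in this section. As a minimal sketch (the name and exact details below are illustrative, not the upstream implementation), such a helper zips the per-clone (gradient, variable) lists position by position: since all clones share variables, position i holds the same variable in every clone's list, and the per-clone gradients for that variable can be combined with tf.add_n.

import tensorflow as tf

def sum_clones_gradients_sketch(clone_grads):
  """Sums the gradients of each variable across all clones (sketch).

  Args:
    clone_grads: A list with one entry per clone, where each entry is the
      list of (gradient, variable) tuples returned by compute_gradients().

  Returns:
    A single list of (summed_gradient, variable) tuples.
  """
  sum_grads = []
  # zip(*clone_grads) yields, for each variable position, the tuple of
  # (gradient, variable) pairs from every clone.
  for grad_and_vars in zip(*clone_grads):
    grads = [g for g, _ in grad_and_vars if g is not None]
    # All clones refer to the same variable at this position.
    var = grad_and_vars[0][1]
    if grads:
      # Elementwise sum of the per-clone gradients.
      summed = tf.add_n(grads, name=var.op.name + '/sum_grads')
      sum_grads.append((summed, var))
  return sum_grads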
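
For context, a hypothetical end-to-end call could look like the following; `deploy_config` and `model_fn` are illustrative placeholders built elsewhere (e.g. via model_deploy.DeploymentConfig and a per-clone model-building function), not defined in this section.

# Hypothetical usage of optimize_clones; deploy_config and model_fn
# are placeholders constructed outside this section.
clones = model_deploy.create_clones(deploy_config, model_fn)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
# total_loss averages the clone losses (plus regularization); the
# returned gradients are already summed across clones.
total_loss, grads_and_vars = optimize_clones(clones, optimizer)
train_op = optimizer.apply_gradients(grads_and_vars)

Because each clone's loss is scaled by 1/num_clones inside `_optimize_clone`, summing the scaled losses with tf.add_n yields the average, while the gradients are summed rather than averaged so that apply_gradients sees the gradient of that averaged total loss.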