def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
                    **kwargs):
    """Gather one clone's loss and compute its gradients.

    The loss is obtained from `_gather_clone_loss`; when a loss is
    available, the gradients are computed inside the clone's device scope.

    Args:
        optimizer: Object providing `compute_gradients()` (a tf.Optimizer).
        clone: A Clone namedtuple; its `device` attribute selects the device
            scope used for the gradient computation.
        num_clones: The number of clones being deployed (forwarded to
            `_gather_clone_loss`).
        regularization_losses: Possibly empty list of regularization losses
            (forwarded to `_gather_clone_loss`).
        **kwargs: Keyword arguments passed through to `compute_gradients()`.

    Returns:
        A tuple `(clone_loss, clone_grads_and_vars)`.
        - clone_loss: The value returned by `_gather_clone_loss`; may be None.
        - clone_grads_and_vars: The result of `compute_gradients()` for the
          clone loss, or None when `clone_loss` is None.
    """
    clone_loss = _gather_clone_loss(clone, num_clones, regularization_losses)

    # Nothing to differentiate: report the missing loss and no gradients.
    if clone_loss is None:
        return clone_loss, None

    # Compute the gradients inside the clone's own device scope.
    with tf.device(clone.device):
        grads_and_vars = optimizer.compute_gradients(clone_loss, **kwargs)
    return clone_loss, grads_and_vars
评论列表
文章目录