model_deploy.py source code


Project: segmentation-models    Author: desimone
import tensorflow as tf  # TF 1.x graph-mode API (tf.get_collection, tf.GraphKeys)

# `_optimize_clone` and `_sum_clones_gradients` are helpers defined elsewhere
# in model_deploy.py; they are not shown here.


def optimize_clones(clones, optimizer, regularization_losses=None, **kwargs):
    """Compute clone losses and gradients for the given list of `Clones`.

    Note: The regularization_losses are added to the first clone's losses only.

    Args:
        clones: List of `Clones` created by `create_clones()`.
        optimizer: An `Optimizer` object.
        regularization_losses: Optional list of regularization losses. If None,
            they are gathered from `tf.GraphKeys.REGULARIZATION_LOSSES`. Pass
            `[]` to exclude them.
        **kwargs: Optional keyword arguments to pass to `compute_gradients`.

    Returns:
        A tuple (total_loss, grads_and_vars).
          - total_loss: A Tensor containing the average of the clone losses,
            including the regularization loss.
          - grads_and_vars: A list of (gradient, variable) tuples containing the
            sum of the gradients for each variable.
    """
    grads_and_vars = []
    clones_losses = []
    num_clones = len(clones)
    if regularization_losses is None:
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
    for clone in clones:
        with tf.name_scope(clone.scope):
            clone_loss, clone_grad = _optimize_clone(optimizer, clone,
                                                     num_clones,
                                                     regularization_losses,
                                                     **kwargs)
            if clone_loss is not None:
                clones_losses.append(clone_loss)
                grads_and_vars.append(clone_grad)
            # Only use regularization_losses for the first clone.
            regularization_losses = None
    # Compute total_loss by summing all clones_losses. Each clone loss is
    # expected to be pre-scaled by 1/num_clones in _optimize_clone, which is
    # what makes this sum the average described in the docstring.
    total_loss = tf.add_n(clones_losses, name='total_loss')
    # Sum the gradients across clones.
    grads_and_vars = _sum_clones_gradients(grads_and_vars)
    return total_loss, grads_and_vars
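The last step hands the per-clone `(gradient, variable)` lists to `_sum_clones_gradients`, which, per the docstring, returns one tuple per variable holding the sum of that variable's gradients. The helper's body is not shown on this page, so the following is only a framework-free sketch of that idea. It assumes gradients are plain floats (or `None`), variables are identified by name strings, and every clone lists its variables in the same order. The function name `sum_clones_gradients_sketch` is hypothetical.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical stand-ins: a gradient is a float (or None), a variable is its name.
GradAndVar = Tuple[Optional[float], str]


def sum_clones_gradients_sketch(
    clone_grads: Sequence[Sequence[GradAndVar]],
) -> List[Tuple[float, str]]:
    """Sum per-variable gradients across clones (framework-free sketch)."""
    summed: List[Tuple[float, str]] = []
    # zip(*clone_grads) yields one tuple per variable, holding that
    # variable's (gradient, variable) pair from every clone.
    for grads_and_vars in zip(*clone_grads):
        var = grads_and_vars[0][1]
        grads = []
        for grad, other_var in grads_and_vars:
            # Every clone must list its variables in the same order.
            if other_var != var:
                raise ValueError(f"variable mismatch: {other_var!r} != {var!r}")
            # Skip variables this clone's loss does not depend on.
            if grad is not None:
                grads.append(grad)
        # Drop variables that received no gradient from any clone.
        if grads:
            summed.append((sum(grads), var))
    return summed


# Two clones sharing the variables "w" and "b"; "b" gets no gradient from clone 1.
clone_1 = [(0.5, "w"), (None, "b")]
clone_2 = [(1.5, "w"), (0.25, "b")]
print(sum_clones_gradients_sketch([clone_1, clone_2]))  # [(2.0, 'w'), (0.25, 'b')]
```

Summing rather than averaging the gradients is consistent with the loss side: if each clone's loss has already been divided by the number of clones, the summed gradients are the gradients of the averaged loss.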
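The docstring describes `total_loss` as an average, yet `optimize_clones` combines the clone losses with `tf.add_n`, a plain sum. The two agree if `_optimize_clone` divides each clone's loss by `num_clones` before returning it, which is how the helper of the same name in TF-slim's `model_deploy.py` behaves. This is an assumption about the helper, whose body is not shown here. The plain-Python sketch below reproduces that bookkeeping with floats. It also shows why `regularization_losses` is reset to `None` after the first clone: the regularization term is counted once, not `num_clones` times. The name `total_loss_sketch` is hypothetical.

```python
from typing import Sequence


def total_loss_sketch(
    per_clone_losses: Sequence[Sequence[float]],
    regularization_losses: Sequence[float],
) -> float:
    """Mimic how optimize_clones combines clone and regularization losses.

    Assumption: each clone's loss is divided by num_clones, as in
    TF-slim's _optimize_clone; the real helper is not shown on this page.
    """
    num_clones = len(per_clone_losses)
    clone_losses = []
    reg_losses = list(regularization_losses)
    for losses in per_clone_losses:
        clone_loss = sum(losses)
        if num_clones > 1:
            clone_loss /= num_clones  # scale so the final sum is a mean
        # Regularization is added to the first clone only.
        clone_loss += sum(reg_losses)
        reg_losses = []
        clone_losses.append(clone_loss)
    # Summing the scaled clone losses gives mean(clone loss) + regularization.
    return sum(clone_losses)


# Two clones with losses 2.0 and 4.0 plus a 0.5 regularization term:
# (2.0 + 4.0) / 2 + 0.5 = 3.5
print(total_loss_sketch([[2.0], [4.0]], [0.5]))  # 3.5
```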