head.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def _multi_head(heads, loss_weights=None):
  """Creates a MultiHead stemming from same logits/hidden layer.

  Args:
    heads: list of _Head objects.
    loss_weights: optional list of weights to be used to combine losses from
        each head. All losses are weighted equally if not provided.

  Returns:
    A _Head instance that combines multiple heads.

  Raises:
    ValueError: if heads and loss_weights have different size.
  """
  if loss_weights:
    if len(loss_weights) != len(heads):
      raise ValueError("heads and loss_weights must have same size")

  def _weighted_loss_combiner(losses):
    if loss_weights:
      if len(losses) != len(loss_weights):
        raise ValueError("losses and loss_weights must have same size")
      weighted_losses = []
      for loss, weight in zip(losses, loss_weights):
        weighted_losses.append(math_ops.multiply(loss, weight))
      return math_ops.add_n(weighted_losses)
    else:
      return math_ops.add_n(losses)

  return _MultiHead(heads, loss_combiner=_weighted_loss_combiner)


# TODO(zakaria): Make the classes public once we are ready for users to subclass
#   them.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号