metric_ops.py source code

python

Project: lsdc  Author: febert
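# Imports inferred from the names used in the function body. `math_ops.mul`
# is the pre-1.0 TensorFlow name; it was renamed `multiply` in TensorFlow 1.0.
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _broadcast_weights(weights, values):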
  """Broadcast `weights` to the same shape as `values`.

  This returns a version of `weights` following the same broadcast rules as
  `mul(weights, values)`. When computing a weighted average, use this function
  to broadcast `weights` before summing them; e.g.,
  `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.

  Args:
    weights: `Tensor` whose shape is broadcastable to `values`.
    values: `Tensor` of any shape.

  Returns:
    `weights` broadcast to `values` shape.
  """
  weights_shape = weights.get_shape()
  values_shape = values.get_shape()
  if (weights_shape.is_fully_defined() and
      values_shape.is_fully_defined() and
      weights_shape.is_compatible_with(values_shape)):
    return weights
  return math_ops.mul(
      weights, array_ops.ones_like(values), name='broadcast_weights')
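
The docstring's advice to broadcast `weights` before summing them can be illustrated with a small sketch in plain Python, without TensorFlow. The helper names `broadcast_weights_2d` and `weighted_mean` are hypothetical and not part of metric_ops.py; the sketch assumes 2-D nested lists and mimics `weights * ones_like(values)` by repeating any size-1 dimension of `weights`.

```python
def broadcast_weights_2d(weights, values):
    """Expand `weights` (nested lists) to the row/column layout of `values`.

    Hypothetical helper: a dimension of size 1 in `weights` is repeated to
    match `values`, which is what `weights * ones_like(values)` achieves.
    """
    rows, cols = len(values), len(values[0])
    out = []
    for i in range(rows):
        # Reuse the single row if `weights` has only one row.
        w_row = weights[i if len(weights) > 1 else 0]
        # Reuse the single column if the row has only one entry.
        out.append([w_row[j if len(w_row) > 1 else 0] for j in range(cols)])
    return out


def weighted_mean(weights, values):
    """Weighted mean that divides by the sum of the broadcast weights."""
    w = broadcast_weights_2d(weights, values)
    num = sum(wi * vi for wr, vr in zip(w, values) for wi, vi in zip(wr, vr))
    den = sum(wi for wr in w for wi in wr)
    return num / den


values = [[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]]
weights = [[1.0], [3.0]]  # one weight per row, i.e. shape [2, 1]

correct = weighted_mean(weights, values)

# Same numerator, but divided by the sum of the un-broadcast weights:
# each weight is counted once instead of once per element it applies to.
naive = (1.0 * (1.0 + 2.0 + 3.0) + 3.0 * (4.0 + 5.0 + 6.0)) / (1.0 + 3.0)

print(correct)  # 4.25
print(naive)    # 12.75
```

The naive result is larger than every entry of `values`, so it cannot be an average. Dividing by the sum of the broadcast weights keeps the result within the range of the data.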