def _broadcast_weights(weights, values):
    """Expand `weights` so that it matches the shape of `values`.

    The result follows the same broadcasting rules as `mul(weights, values)`.
    Use it when computing a weighted average, so that the denominator sums
    the broadcast weights rather than the raw ones; e.g.,
    `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.

    When the static shapes of both tensors are fully defined and compatible
    with each other, `weights` is returned unchanged and no op is added.
    Otherwise `weights` is multiplied by `ones_like(values)`, with a control
    dependency on `_assert_weights_rank(weights, values)`.

    Args:
      weights: `Tensor` whose rank is either 0, or the same rank as `values`,
        and must be broadcastable to `values` (i.e., all dimensions must be
        either `1`, or the same as the corresponding `values` dimension).
      values: `Tensor` of any shape.

    Returns:
      `weights` broadcast to `values` shape.
    """
    with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope:
        w_shape = weights.get_shape()
        v_shape = values.get_shape()

        # Fast path: shapes are statically known and already agree.
        statically_known = (
            w_shape.is_fully_defined() and v_shape.is_fully_defined())
        if statically_known and w_shape.is_compatible_with(v_shape):
            return weights

        # Slow path: broadcast via multiplication by ones. Both ops are
        # created inside the control-dependency context so that they are
        # gated on the rank assertion.
        rank_assertion = _assert_weights_rank(weights, values)
        with ops.control_dependencies((rank_assertion,)):
            ones = array_ops.ones_like(values)
            return math_ops.multiply(weights, ones, name=scope)
metric_ops.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录