def _broadcast_weights(weights, values):
    """Expand `weights` so that it has the same shape as `values`.

    The result follows the broadcasting rules of `mul(weights, values)`.
    It is meant for the denominator of a weighted average, where the
    weights must be summed over the full shape of `values`, e.g.
    `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.

    Args:
      weights: `Tensor` whose shape is broadcastable to `values`.
      values: `Tensor` of any shape.

    Returns:
      `weights` itself when both static shapes are fully defined and
      compatible with each other; otherwise the product of `weights` and
      `ones_like(values)`, created under the name 'broadcast_weights'.
    """
    w_shape = weights.get_shape()
    v_shape = values.get_shape()

    # Static shortcut: with both shapes fully known and already compatible,
    # no extra op is needed and the input is handed back untouched.
    statically_known = w_shape.is_fully_defined() and v_shape.is_fully_defined()
    if statically_known and w_shape.is_compatible_with(v_shape):
        return weights

    # Otherwise let the multiply op perform the broadcast: multiplying by a
    # tensor of ones shaped like `values` expands `weights` without changing
    # its entries.
    ones = array_ops.ones_like(values)
    return math_ops.mul(weights, ones, name='broadcast_weights')
评论列表
文章目录