def _num_present(losses, weights, per_batch=False):
    """Computes the number of elements in the loss function induced by `weights`.

    A given weights tensor induces different numbers of usable elements in the
    `losses` tensor. The `weights` tensor is broadcast across `losses` for all
    possible dimensions. For example, if `losses` is a tensor of dimension
    [4, 5, 6, 3] and `weights` is a tensor of size [4, 5], then `weights` is, in
    effect, tiled to match the size of `losses`. Following this effective tile,
    the total number of present elements is the number of non-zero weights.

    Args:
      losses: A tensor of size [batch_size, d1, ... dN].
      weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
        A rank-0 (scalar) `weights` is handled by a dedicated branch. The
        static rank of `weights` must be known.
      per_batch: Whether to return the number of elements per batch or as a sum
        total.

    Returns:
      The number of present (non-zero) elements in the losses tensor. If
      `per_batch` is True, the value is returned as a tensor of size
      [batch_size]. Otherwise, a single scalar tensor is returned.

    Raises:
      ValueError: If the static rank of `weights` is unknown.
    """
    # Read the static rank once; it drives both the branch selection and the
    # slice offset below. It is None when the rank cannot be inferred
    # statically, in which case none of the logic below is well defined.
    weights_rank = weights.get_shape().ndims
    if weights_rank is None:
        raise ValueError(
            "weights must have a statically known rank to count the number "
            "of present elements.")

    # If weights is a scalar, it's easy to compute:
    if weights_rank == 0:
        # First dimension of `losses`, reshaped from a 1-element vector to a
        # scalar.
        batch_size = array_ops.reshape(
            array_ops.slice(array_ops.shape(losses), [0], [1]), [])
        # Every batch entry holds the same number of elements.
        num_per_batch = math_ops.div(
            math_ops.to_float(array_ops.size(losses)),
            math_ops.to_float(batch_size))
        # A zero scalar weight means no element is present at all.
        num_per_batch = array_ops.where(
            math_ops.equal(weights, 0), 0.0, num_per_batch)
        # Expand the scalar count into a vector with one entry per batch item.
        num_per_batch = math_ops.multiply(
            array_ops.ones(array_ops.reshape(batch_size, [1])), num_per_batch)
        return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)

    # First, count the number of nonzero weights (per batch entry, i.e. summed
    # over every weights dimension except the first):
    reduction_indices = list(range(1, weights_rank))
    num_nonzero_per_batch = math_ops.reduce_sum(
        math_ops.to_float(math_ops.not_equal(weights, 0)),
        reduction_indices=reduction_indices)

    # Next, determine the number of elements that weights would broadcast to
    # (the product of the trailing `losses` dimensions not covered by weights):
    broadcast_dims = array_ops.slice(
        array_ops.shape(losses), [weights_rank], [-1])
    num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims))

    num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast)
    return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
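# loss_ops.py source code
# Language: python
# (Web-page residue: 21 views, 0 favorites, 0 likes, 0 comments;
#  "comment list" and "table of contents" navigation labels.)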