def _count_condition(values, weights=None, metrics_collections=None,
                     updates_collections=None):
    """Keeps a running (optionally weighted) tally of `True` entries.

    The boolean input is cast to float so each `True` contributes 1.0 and each
    `False` contributes 0.0. When `weights` is supplied, the cast values are
    multiplied by it element-wise, so a weight of 0 masks an entry out. The
    per-batch sum is added to a scalar accumulator obtained from
    `_create_local`.

    Args:
      values: A `bool` `Tensor` of arbitrary size.
      weights: An optional `Tensor` that is multiplied with `values`; when
        `None`, every entry effectively has weight 1.
      metrics_collections: An optional list of collections to which the
        metric value tensor is added. Skipped when `None` or empty.
      updates_collections: An optional list of collections to which the
        update op is added. Skipped when `None` or empty.

    Returns:
      value_tensor: An identity of the accumulator, i.e. the current total.
      update_op: An op that adds the current batch's (weighted) count to the
        accumulator.

    Raises:
      This block performs no explicit validation besides the dtype check on
      `values`. NOTE(review): shape mismatches between `weights` and `values`
      are presumably reported by the multiplication op itself -- confirm
      against the TensorFlow version in use.
    """
    # Reject non-boolean input before any ops are built.
    check_ops.assert_type(values, dtypes.bool)

    # Scalar (shape=[]) accumulator that holds the total across batches.
    running_total = _create_local('count', shape=[])

    # True -> 1.0, False -> 0.0, so a sum over the tensor is a count.
    contributions = math_ops.to_float(values)
    if weights is not None:
        float_weights = math_ops.to_float(weights)
        contributions = math_ops.mul(contributions, float_weights)

    # The read-out tensor is created before the update op, keeping the same
    # op-construction order as the accumulate step below expects.
    value_tensor = array_ops.identity(running_total)
    batch_sum = math_ops.reduce_sum(contributions)
    update_op = state_ops.assign_add(running_total, batch_sum)

    # Registration is optional: None and empty lists are both skipped.
    if metrics_collections:
        ops.add_to_collections(metrics_collections, value_tensor)
    if updates_collections:
        ops.add_to_collections(updates_collections, update_op)

    return value_tensor, update_op
评论列表
文章目录