def _numeric_combine(x, fn, reduce_instance_dims=True, name=None):
    """Apply an analyzer with _NumPyCombinerSpec to given input.

    Args:
      x: The input `tf.Tensor`. Any other type is rejected.
      fn: The function handed to `_NumPyCombinerSpec`. Unless `name` is
        given, its `__name__` (or, failing that, its type name) is used as
        the analyzer name.
      reduce_instance_dims: If True, the result is a scalar (shape `()`).
        If False, the result has the input's shape minus the first (batch)
        dimension, or an unknown shape when the input's rank is unknown.
      name: Optional name for the analyzer. Defaults to the name of `fn`.

    Returns:
      Whatever `combine_analyzer` returns for the given input, dtype, shape
      and combiner spec.

    Raises:
      TypeError: If `x` is not a `tf.Tensor`.
    """
    if not isinstance(x, tf.Tensor):
        # Wrap `x` in a 1-tuple so that a tuple argument is formatted with
        # %r rather than being unpacked as format arguments.
        raise TypeError('Expected a Tensor, but got %r' % (x,))

    if reduce_instance_dims:
        # If reducing over all dimensions, result is scalar.
        shape = ()
    elif x.shape.dims is not None:
        # If reducing over batch dimensions, with known rank, the result will
        # be the same shape as the input, but without the batch dimension.
        # Individual entries may still be None for unknown dimension sizes.
        shape = x.shape.as_list()[1:]
    else:
        # If reducing over batch dimensions, with unknown rank, the result
        # will also have unknown shape.
        shape = None

    if name is None:
        # Callables such as functools.partial objects have no __name__.
        name = getattr(fn, '__name__', type(fn).__name__)

    return combine_analyzer(
        x, x.dtype, shape, _NumPyCombinerSpec(fn, reduce_instance_dims), name)
评论列表
文章目录