def _check_shape(value, expected_shape):
    """Verify that Tensor `value` has the shape given by `expected_shape`.

    A runtime assertion comparing `expected_shape` against the dynamic shape
    of `value` is attached, as a control dependency, to the returned tensor.
    When `expected_shape` is not a `Tensor` by the time the check is built,
    it is additionally merged into the returned tensor's static shape.

    Args:
      value: A Tensor, possibly with associated static shape information.
      expected_shape: A `TensorShape`, a list of `int32`, or a vector `Tensor`.

    Returns:
      A Tensor matching `value`. Accessing this tensor tests the assertions
      on its shape. If `expected_shape` is not a `Tensor`, the returned
      tensor's static shape has been set as well.

    Raises:
      ValueError: if `expected_shape` is not a `Tensor` and the known shape of
        `value` cannot be merged with `expected_shape`.
    """
    assert isinstance(value, ops.Tensor)

    # Reduce `expected_shape` to a plain Python list whenever possible: a
    # TensorShape is converted directly, and a Tensor is converted when
    # `tensor_util.constant_value` returns something other than None.
    if isinstance(expected_shape, tensor_shape.TensorShape):
        expected_shape = expected_shape.as_list()
    if isinstance(expected_shape, ops.Tensor):
        static_dims = tensor_util.constant_value(expected_shape)
        if static_dims is not None:
            expected_shape = [int(dim) for dim in static_dims]

    # From here on `expected_shape` is either still a Tensor (dynamic) or a
    # Python list (static); several steps below branch on this.
    shape_is_dynamic = isinstance(expected_shape, ops.Tensor)

    # The rank is handed to `_check_rank` before the per-dimension comparison.
    if shape_is_dynamic:
        expected_rank = array_ops.size(expected_shape)
    else:
        expected_rank = len(expected_shape)
    value = _check_rank(value, expected_rank)

    # Build the runtime assertion: every dimension must match, and on failure
    # the message reports both the expected and the received shape.
    shapes_agree = math_ops.reduce_all(
        math_ops.equal(expected_shape, array_ops.shape(value)))
    failure_message = string_ops.string_join([
        "Shape of tensor %s should be: " % value.name,
        string_ops.as_string(expected_shape),
        ", shape received: ",
        string_ops.as_string(array_ops.shape(value)),
    ])
    shape_assertion = control_flow_ops.Assert(shapes_agree, [failure_message])

    with ops.control_dependencies([shape_assertion]):
        # The identity op carries the assertion as a control dependency.
        checked = array_ops.identity(value, name="shape_checked")
        if not shape_is_dynamic:
            # A static expected shape can also refine the static shape of the
            # result; an incompatible merge is reported against `value`.
            try:
                merged_shape = checked.get_shape().merge_with(expected_shape)
                checked.set_shape(merged_shape)
            except ValueError as e:
                raise ValueError("Shape check failed for %s: %s"
                                 % (value.name, str(e)))
        return checked
评论列表
文章目录