def initialize(self, name=None):
  """Return the first-step `(finished, next_inputs)` pair for TrainingHelper.

  A batch entry is finished immediately when its sequence length is zero.
  If every entry is finished, `self._zero_inputs` is returned; otherwise
  element 0 is read from each TensorArray in `self._input_tas`.

  Args:
    name: Optional name for the enclosing name scope.

  Returns:
    A `(finished, next_inputs)` tuple.
  """
  with ops.name_scope(name, "TrainingHelperInitialize"):
    def _read_first_step():
      return nest.map_structure(lambda ta: ta.read(0), self._input_tas)

    finished = math_ops.equal(0, self._sequence_length)
    first_inputs = control_flow_ops.cond(
        math_ops.reduce_all(finished),
        lambda: self._zero_inputs,
        _read_first_step)
    return (finished, first_inputs)
python类reduce_all()的实例源码
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
  """next_inputs_fn for TrainingHelper.

  Marks entries whose sequence length is <= `time + 1` as finished. The
  input TensorArrays are read at index `time + 1` only in the branch taken
  when at least one entry is still unfinished; otherwise `self._zero_inputs`
  is returned.

  Args:
    time: Current decoding step.
    outputs: Only used as a name-scope input.
    state: Decoder state, returned unchanged.
    name: Optional name for the enclosing name scope.
    **unused_kwargs: Ignored.

  Returns:
    A `(finished, next_inputs, state)` tuple.
  """
  with ops.name_scope(name, "TrainingHelperNextInputs",
                      [time, outputs, state]):
    step = time + 1
    finished = (step >= self._sequence_length)
    all_finished = math_ops.reduce_all(finished)

    def _read_step():
      return nest.map_structure(lambda ta: ta.read(step), self._input_tas)

    inputs = control_flow_ops.cond(
        all_finished, lambda: self._zero_inputs, _read_step)
    return (finished, inputs, state)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for ScheduledEmbeddingTrainingHelper.

  Starts from the parent helper's next inputs. Every batch entry whose
  `sample_ids` value is > -1 is replaced by the embedding of that id.
  Entries with `sample_ids <= -1` keep the parent's inputs. When every entry
  is finished, the parent's inputs are returned unchanged.

  Returns:
    A `(finished, next_inputs, state)` tuple.
  """
  with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                      [time, outputs, state, sample_ids]):
    parent = super(ScheduledEmbeddingTrainingHelper, self)
    finished, base_next_inputs, state = parent.next_inputs(
        time=time,
        outputs=outputs,
        state=state,
        sample_ids=sample_ids,
        name=name)

    def _indices_where(mask):
      # Coordinates of True entries, cast to int32 for gather/scatter_nd.
      return math_ops.cast(array_ops.where(mask), dtypes.int32)

    def maybe_sample():
      """Perform scheduled sampling."""
      sampled_idx = _indices_where(sample_ids > -1)
      kept_idx = _indices_where(sample_ids <= -1)
      sampled_ids = array_ops.gather(
          sample_ids, array_ops.reshape(sampled_idx, [-1]))
      kept_inputs = array_ops.gather(
          base_next_inputs, array_ops.reshape(kept_idx, [-1]))
      embedded = self._embedding_fn(sampled_ids)
      out_shape = array_ops.shape(base_next_inputs)
      sampled_part = array_ops.scatter_nd(
          indices=sampled_idx, updates=embedded, shape=out_shape)
      kept_part = array_ops.scatter_nd(
          indices=kept_idx, updates=kept_inputs, shape=out_shape)
      # `> -1` and `<= -1` are complementary, so the two scatters cover
      # disjoint positions and summing them merges the results.
      return sampled_part + kept_part

    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished, lambda: base_next_inputs, maybe_sample)
    return (finished, next_inputs, state)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for GreedyEmbeddingHelper.

  An entry is finished when its sampled id equals `self._end_token`. While
  any entry is still running, the next inputs are the embeddings of
  `sample_ids`. Once all are finished, `self._start_inputs` is returned.

  Returns:
    A `(finished, next_inputs, state)` tuple; `state` is passed through.
  """
  del time, outputs  # unused by next_inputs_fn
  finished = math_ops.equal(sample_ids, self._end_token)

  def _embed_samples():
    return self._embedding_fn(sample_ids)

  def _filler_inputs():
    # If we're finished, the next_inputs value doesn't matter.
    return self._start_inputs

  next_inputs = control_flow_ops.cond(
      math_ops.reduce_all(finished), _filler_inputs, _embed_samples)
  return (finished, next_inputs, state)
def initialize(self, name=None):
  """Build the step-0 outputs of TrainingHelper.

  Args:
    name: Optional name for the enclosing name scope.

  Returns:
    `(finished, next_inputs)`. `finished` is `self._sequence_length == 0`
    elementwise. `next_inputs` is `self._zero_inputs` when every entry is
    already finished, and otherwise element 0 of each input TensorArray.
  """
  with ops.name_scope(name, "TrainingHelperInitialize"):
    is_empty = math_ops.equal(0, self._sequence_length)
    nothing_to_read = math_ops.reduce_all(is_empty)

    def _zeros():
      return self._zero_inputs

    def _first_inputs():
      return nest.map_structure(lambda ta: ta.read(0), self._input_tas)

    inputs = control_flow_ops.cond(nothing_to_read, _zeros, _first_inputs)
    return is_empty, inputs
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
  """next_inputs_fn for TrainingHelper.

  `outputs` is only used as a name-scope input, and extra keyword arguments
  are ignored. `state` is passed through unchanged.

  Returns:
    `(finished, next_inputs, state)`. `finished` marks entries whose
    sequence length is <= `time + 1`.
  """
  with ops.name_scope(name, "TrainingHelperNextInputs",
                      [time, outputs, state]):
    read_index = time + 1
    finished = (read_index >= self._sequence_length)
    done = math_ops.reduce_all(finished)

    def _zeros():
      return self._zero_inputs

    def _ground_truth():
      return nest.map_structure(
          lambda input_ta: input_ta.read(read_index), self._input_tas)

    return (finished,
            control_flow_ops.cond(done, _zeros, _ground_truth),
            state)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for ScheduledEmbeddingTrainingHelper.

  Positions with `sample_ids > -1` receive `self._embedding_fn` of the
  sampled id. Positions with `sample_ids <= -1` keep the parent helper's
  next inputs. If every entry is finished, the parent's inputs are returned
  unchanged.
  """
  with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                      [time, outputs, state, sample_ids]):
    base = super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
        time=time, outputs=outputs, state=state, sample_ids=sample_ids,
        name=name)
    finished, base_next_inputs, state = base

    def maybe_sample():
      """Perform scheduled sampling."""
      sampled_pos = math_ops.cast(
          array_ops.where(sample_ids > -1), dtypes.int32)
      kept_pos = math_ops.cast(
          array_ops.where(sample_ids <= -1), dtypes.int32)
      target_shape = array_ops.shape(base_next_inputs)

      def _scatter(indices, updates):
        return array_ops.scatter_nd(
            indices=indices, updates=updates, shape=target_shape)

      sampled_inputs = self._embedding_fn(
          array_ops.gather(sample_ids, array_ops.reshape(sampled_pos, [-1])))
      kept_inputs = array_ops.gather(
          base_next_inputs, array_ops.reshape(kept_pos, [-1]))
      return _scatter(sampled_pos, sampled_inputs) + _scatter(
          kept_pos, kept_inputs)

    next_inputs = control_flow_ops.cond(
        math_ops.reduce_all(finished), lambda: base_next_inputs, maybe_sample)
    return (finished, next_inputs, state)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for GreedyEmbeddingHelper.

  Returns:
    `(finished, next_inputs, state)`. `finished` is `sample_ids == end_token`
    elementwise, and `state` is returned unchanged.
  """
  del time, outputs  # unused by next_inputs_fn
  finished = math_ops.equal(sample_ids, self._end_token)
  all_finished = math_ops.reduce_all(finished)
  # If we're finished, the next_inputs value doesn't matter, so the start
  # inputs are reused as filler.
  next_inputs = control_flow_ops.cond(
      all_finished,
      lambda: self._start_inputs,
      lambda: self._embedding_fn(sample_ids))
  return finished, next_inputs, state
def next_inputs(self, sample_ids, name=None):
  """Compute the next decoder inputs from sampled token ids.

  Args:
    sample_ids: Sampled token ids, compared against `self.config.eos_token`.
    name: Unused; kept for interface compatibility.

  Returns:
    `(all_finished, next_inputs)`. `all_finished` is True when every id
    equals the EOS token. `next_inputs` holds the embeddings of
    `sample_ids` from `self.target_embedding`, or `beam_width` copies of
    the EOS embedding once everything is finished.
  """
  eos = self.config.eos_token
  finished = math_ops.equal(sample_ids, eos)
  all_finished = math_ops.reduce_all(finished)

  def _eos_embeddings():
    # If we're finished, the next_inputs value doesn't matter.
    eos_ids = tf.tile([eos], [self.config.beam_width])
    return tf.nn.embedding_lookup(self.target_embedding, eos_ids)

  def _sample_embeddings():
    return tf.nn.embedding_lookup(self.target_embedding, sample_ids)

  next_inputs = control_flow_ops.cond(
      all_finished, _eos_embeddings, _sample_embeddings)
  return all_finished, next_inputs
def all(x, axis=None, keepdims=False):
  """Logical AND reduction.

  `x` is first cast to bool, then reduced with `math_ops.reduce_all`.

  Arguments:
    x: Tensor or variable.
    axis: axis along which to perform the reduction (normalized against the
      rank of `x` by `_normalize_axis`).
    keepdims: whether to keep the reduced axes (with length 1) or drop them.

  Returns:
    A bool tensor.
  """
  axis = _normalize_axis(axis, ndim(x))
  # Cast so that non-bool inputs (e.g. 0/1 numeric tensors) can be reduced.
  x = math_ops.cast(x, dtypes_module.bool)
  return math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
def random_crop(value, size, seed=None, name=None):
  """Randomly crops a tensor to a given size.

  Slices a shape `size` portion out of `value` at a uniformly chosen offset.
  Requires `value.shape >= size`; this is checked at run time with an
  Assert.

  If a dimension should not be cropped, pass the full size of that dimension.
  For example, RGB images can be cropped with
  `size = [crop_height, crop_width, 3]`.

  Args:
    value: Input tensor to crop.
    size: 1-D tensor with size the rank of `value`.
    seed: Python integer. Used to create a random seed. See
      @{tf.set_random_seed}
      for behavior.
    name: A name for this operation (optional).

  Returns:
    A cropped tensor of the same rank as `value` and shape `size`.
  """
  # TODO(shlens): Implement edge case to guarantee output size dimensions.
  # If size > value.shape, zero pad the result so that it always has shape
  # exactly size.
  with ops.name_scope(name, "random_crop", [value, size]) as name:
    value = ops.convert_to_tensor(value, name="value")
    size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size")
    value_shape = array_ops.shape(value)
    fits = math_ops.reduce_all(value_shape >= size)
    shape_check = control_flow_ops.Assert(
        fits,
        ["Need value.shape >= size, got ", value_shape, size],
        summarize=1000)
    value_shape = control_flow_ops.with_dependencies(
        [shape_check], value_shape)
    # Number of valid start positions along each dimension.
    num_offsets = value_shape - size + 1
    raw_offset = random_uniform(
        array_ops.shape(value_shape),
        dtype=size.dtype,
        maxval=size.dtype.max,
        seed=seed)
    offset = raw_offset % num_offsets
    return array_ops.slice(value, offset, size, name=name)
def _all_equal(tensor0, tensor1):
  """Return a bool tensor that is True iff every element of the inputs is equal.

  The comparison is named 'equal'. The reduction is given the name of the
  enclosing 'all_equal' scope itself.
  """
  with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
    elementwise = math_ops.equal(tensor0, tensor1, name='equal')
    return math_ops.reduce_all(elementwise, name=scope)
def assert_close(
    x, y, data=None, summarize=None, message=None, name="assert_close"):
  """Assert that x and y are within machine epsilon of each other.

  Integer inputs are compared for exact equality via
  `check_ops.assert_equal`. Floating-point inputs pass when every element
  satisfies `|x - y| <= eps`, where `eps = np.finfo(x.dtype).eps`.

  Args:
    x: Numeric `Tensor`
    y: Numeric `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
  """
  message = message or ""
  x = ops.convert_to_tensor(x, name="x")
  y = ops.convert_to_tensor(y, name="y")

  if x.dtype.is_integer:
    return check_ops.assert_equal(
        x, y, data=data, summarize=summarize, message=message, name=name)

  with ops.name_scope(name, "assert_close", [x, y, data]):
    # `eps` is machine epsilon, as the docstring promises. `finfo.resolution`
    # is a different, decimal-precision quantity (1e-6 vs ~1.19e-7 for
    # float32).
    tol = np.finfo(x.dtype.as_numpy_dtype).eps
    if data is None:
      data = [
          message,
          "Condition x ~= y did not hold element-wise: x = ", x.name, x,
          "y = ", y.name, y
      ]
    condition = math_ops.reduce_all(
        math_ops.less_equal(math_ops.abs(x - y), tol))
    return control_flow_ops.Assert(
        condition, data, summarize=summarize)
def _all_equal(tensor0, tensor1):
  """Bool tensor that is True iff tensor0 == tensor1 at every element."""
  scope_values = [tensor0, tensor1]
  with ops.name_scope('all_equal', values=scope_values) as scope:
    same = math_ops.equal(tensor0, tensor1, name='equal')
    result = math_ops.reduce_all(same, name=scope)
  return result
def assert_close(
    x, y, data=None, summarize=None, message=None, name="assert_close"):
  """Assert that x and y are within machine epsilon of each other.

  The default `data` payload is built before dispatching, so integer
  inputs (checked for exact equality with `check_ops.assert_equal`) also
  receive it.

  Args:
    x: Numeric `Tensor`
    y: Numeric `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
  """
  message = message or ""
  x = ops.convert_to_tensor(x, name="x")
  y = ops.convert_to_tensor(y, name="y")
  if data is None:
    data = [message,
            "Condition x ~= y did not hold element-wise: x = ",
            x.name, x,
            "y = ",
            y.name, y]
  if x.dtype.is_integer:
    # Integers have no rounding error: require exact equality.
    return check_ops.assert_equal(
        x, y, data=data, summarize=summarize, message=message, name=name)
  with ops.name_scope(name, "assert_close", [x, y, data]):
    machine_eps = np.finfo(x.dtype.as_numpy_dtype).eps
    within_eps = math_ops.less_equal(math_ops.abs(x - y), machine_eps)
    return control_flow_ops.Assert(
        math_ops.reduce_all(within_eps), data, summarize=summarize)
def do_center_crop(value, size, name=None):
  """Crops the center of a tensor to a given size.

  Slices a shape `size` portion out of `value`, starting at offset
  `(value.shape - size) // 2` in every dimension. Any surplus is split
  evenly; when the surplus is odd, the extra element ends up after the
  crop. Requires `value.shape >= size`, which is checked at run time with an
  Assert.

  If a dimension should not be cropped, pass the full size of that dimension.
  For example, RGB images can be cropped with
  `size = [crop_height, crop_width, 3]`.

  Args:
    value: Input tensor to crop.
    size: 1-D tensor with size the rank of `value`.
    name: A name for this operation (optional).

  Returns:
    A cropped tensor of the same rank as `value` and shape `size`.
  """
  # TODO(shlens): Implement edge case to guarantee output size dimensions.
  # If size > value.shape, zero pad the result so that it always has shape
  # exactly size.
  from tensorflow.python.framework import dtypes
  with ops.op_scope([value, size], name, "center_crop") as name:
    value = ops.convert_to_tensor(value, name="value")
    size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size")
    shape = array_ops.shape(value)
    check = logging_ops.Assert(
        math_ops.reduce_all(shape >= size),
        ["Need value.shape >= size, got ", shape, size])
    shape = control_flow_ops.with_dependencies([check], shape)
    # Deterministic centered window. The previous code sliced at a random
    # offset despite the function's name. Because shape >= size (asserted
    # above), this offset lies in [0, shape - size], so the slice stays in
    # bounds.
    offset = (shape - size) // 2
    return array_ops.slice(value, offset, size, name=name)
def initialize(self, name=None):
  """TrainingHelper initializer: `(finished, next_inputs)` for the first step.

  Args:
    name: Optional name for the enclosing name scope.

  Returns:
    `finished` (sequence length == 0, elementwise) and either
    `self._zero_inputs` (when all entries are finished) or element 0 of
    each input TensorArray.
  """
  with ops.name_scope(name, "TrainingHelperInitialize"):
    finished = math_ops.equal(0, self._sequence_length)
    input_tas = self._input_tas
    next_inputs = control_flow_ops.cond(
        math_ops.reduce_all(finished),
        lambda: self._zero_inputs,
        lambda: nest.map_structure(lambda ta: ta.read(0), input_tas))
    return (finished, next_inputs)
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
  """next_inputs_fn for TrainingHelper (reads inputs at step `time + 1`).

  Returns:
    `(finished, next_inputs, state)`. `next_inputs` is `self._zero_inputs`
    once every entry has reached its sequence length.
  """
  scope_inputs = [time, outputs, state]
  with ops.name_scope(name, "TrainingHelperNextInputs", scope_inputs):
    next_time = time + 1
    finished = (next_time >= self._sequence_length)

    def read_all():
      return nest.map_structure(
          lambda ta: ta.read(next_time), self._input_tas)

    next_inputs = control_flow_ops.cond(
        math_ops.reduce_all(finished),
        lambda: self._zero_inputs,
        read_all)
    return (finished, next_inputs, state)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for ScheduledEmbeddingTrainingHelper.

  Entries with a sampled id (`sample_ids > -1`) are fed the embedding of
  that id. The rest (`sample_ids <= -1`) pass through the parent helper's
  next inputs.
  """
  with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                      [time, outputs, state, sample_ids]):
    (finished, base_next_inputs, state) = (
        super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
            time=time,
            outputs=outputs,
            state=state,
            sample_ids=sample_ids,
            name=name))

    def maybe_sample():
      """Perform scheduled sampling."""
      def _where_int32(cond):
        return math_ops.cast(array_ops.where(cond), dtypes.int32)

      sample_idx = _where_int32(sample_ids > -1)
      passthrough_idx = _where_int32(sample_ids <= -1)
      sample_idx_flat = array_ops.reshape(sample_idx, [-1])
      passthrough_idx_flat = array_ops.reshape(passthrough_idx, [-1])
      ids_to_embed = array_ops.gather(sample_ids, sample_idx_flat)
      passthrough_inputs = array_ops.gather(
          base_next_inputs, passthrough_idx_flat)
      embedded = self._embedding_fn(ids_to_embed)
      full_shape = array_ops.shape(base_next_inputs)
      scattered = [
          array_ops.scatter_nd(indices=idx, updates=upd, shape=full_shape)
          for idx, upd in ((sample_idx, embedded),
                           (passthrough_idx, passthrough_inputs))
      ]
      return scattered[0] + scattered[1]

    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished, lambda: base_next_inputs, maybe_sample)
    return (finished, next_inputs, state)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for GreedyEmbeddingHelper.

  Feeds back the embedding of the greedily chosen ids. `finished` flags ids
  equal to the end token.
  """
  del time, outputs  # unused by next_inputs_fn
  finished = math_ops.equal(sample_ids, self._end_token)
  branches = (
      # If we're finished, the next_inputs value doesn't matter.
      lambda: self._start_inputs,
      lambda: self._embedding_fn(sample_ids),
  )
  next_inputs = control_flow_ops.cond(
      math_ops.reduce_all(finished), *branches)
  return (finished, next_inputs, state)
tensor_util.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 23
收藏 0
点赞 0
评论 0
def _all_equal(tensor0, tensor1):
  """Return whether `tensor0` and `tensor1` are equal at every element.

  Both ops live under an 'all_equal' name scope. The reduction op takes the
  scope's own name.
  """
  with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
    eq = math_ops.equal(tensor0, tensor1, name='equal')
    return math_ops.reduce_all(eq, name=scope)
ops_test.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 22
收藏 0
点赞 0
评论 0
def test_name(self):
  """The reduced tensor's name should contain 'lt_reduce_all'."""
  reduced = ops.reduce_all(self.bool_lt, {'channel'})
  self.assertIn('lt_reduce_all', reduced.name)
ops_test.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 25
收藏 0
点赞 0
评论 0
def test(self):
  """Reducing over 'channel' matches math_ops.reduce_all on axis 1."""
  actual = ops.reduce_all(self.bool_lt, {'channel'})
  remaining_axes = [self.a0, self.a2, self.a3]
  expected = core.LabeledTensor(
      math_ops.reduce_all(self.bool_tensor, 1), remaining_axes)
  self.assertLabeledTensorsEqual(actual, expected)
transformed_distribution.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 23
收藏 0
点赞 0
评论 0
def _logical_and(*args):
  """Convenience function which attempts to statically `reduce_all`.

  Returns constant False if any argument's static value is known and falsy.
  Returns constant True if every argument's static value is known (and hence
  truthy, including the no-argument case). Otherwise builds a dynamic op:
  `logical_and` for exactly two arguments, `reduce_all` over the packed
  arguments for any other count.
  """
  statics = [_static_value(arg) for arg in args]
  if any(s is not None and not bool(s) for s in statics):
    return constant_op.constant(False)
  # Every statically known value is truthy at this point, so the result is
  # static exactly when no value is unknown.
  if all(s is not None for s in statics):
    return constant_op.constant(True)
  if len(args) != 2:
    return math_ops.reduce_all(args)
  return math_ops.logical_and(*args)
distribution_util.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 19
收藏 0
点赞 0
评论 0
def assert_close(
    x, y, data=None, summarize=None, message=None, name="assert_close"):
  """Assert that x and y are within machine epsilon of each other.

  Args:
    x: Numeric `Tensor`
    y: Numeric `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if |x - y| > machine epsilon. Integer
    dtypes are instead checked for exact equality.
  """
  message = message or ""
  x = ops.convert_to_tensor(x, name="x")
  y = ops.convert_to_tensor(y, name="y")
  data = data if data is not None else [
      message,
      "Condition x ~= y did not hold element-wise: x = ", x.name, x,
      "y = ", y.name, y,
  ]
  if not x.dtype.is_integer:
    with ops.name_scope(name, "assert_close", [x, y, data]):
      eps = np.finfo(x.dtype.as_numpy_dtype).eps
      abs_diff = math_ops.abs(x - y)
      condition = math_ops.reduce_all(math_ops.less_equal(abs_diff, eps))
      return control_flow_ops.Assert(condition, data, summarize=summarize)
  return check_ops.assert_equal(
      x, y, data=data, summarize=summarize, message=message, name=name)
uniform_test.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 20
收藏 0
点赞 0
评论 0
def testUniformSamplePdf(self):
  """Densities evaluated at drawn samples should all be strictly positive."""
  low, high = 10.0, [11.0, 100.0]
  with self.test_session():
    dist = uniform_lib.Uniform(low, high)
    densities = dist.prob(dist.sample(10))
    self.assertTrue(math_ops.reduce_all(densities > 0).eval())
reduce_ops_test.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 20
收藏 0
点赞 0
评论 0
def testReduceAll(self):
  """math_ops.reduce_all must agree with np.all on the boolean fixture data."""
  # `np.bool` was a deprecated alias of the builtin `bool` (NumPy 1.20) and
  # was removed in NumPy 1.24, where it raises AttributeError. `np.bool_` is
  # NumPy's boolean scalar type and exists in every NumPy version.
  self._testReduction(math_ops.reduce_all, np.all, np.bool_, self.BOOL_DATA)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for ScheduledOutputTrainingHelper.

  `sample_ids` is used as a boolean mask: it is passed to `array_ops.where`
  and `math_ops.logical_not`. Where it is True, the next input comes from
  this step's `outputs`, projected through `self._next_input_layer` if one
  is set. Where it is False, the parent helper's next inputs are kept. If
  auxiliary input TensorArrays exist, their values at step `time + 1` are
  concatenated onto the sampled outputs along the last axis.
  """
  with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                      [time, outputs, state, sample_ids]):
    (finished, base_next_inputs, state) = (
        super(ScheduledOutputTrainingHelper, self).next_inputs(
            time=time, outputs=outputs, state=state,
            sample_ids=sample_ids, name=name))

    def _with_auxiliary(sampled, indices=None):
      """Concatenate outputs with auxiliary inputs, if they exist."""
      if self._auxiliary_input_tas is None:
        return sampled
      aux_step = time + 1
      aux = nest.map_structure(
          lambda ta: ta.read(aux_step), self._auxiliary_input_tas)
      if indices is not None:
        aux = array_ops.gather_nd(aux, indices)
      return nest.map_structure(
          lambda a, b: array_ops.concat((a, b), -1), sampled, aux)

    def _mask_indices(mask):
      # int32 coordinates of True entries, for gather_nd/scatter_nd.
      return math_ops.cast(array_ops.where(mask), dtypes.int32)

    def maybe_sample():
      """Perform scheduled sampling."""
      if self._next_input_layer is None:
        # No projection layer: select whole rows with the boolean mask.
        return array_ops.where(
            sample_ids, _with_auxiliary(outputs), base_next_inputs)
      take_idx = _mask_indices(sample_ids)
      keep_idx = _mask_indices(math_ops.logical_not(sample_ids))
      taken_outputs = array_ops.gather_nd(outputs, take_idx)
      kept_inputs = array_ops.gather_nd(base_next_inputs, keep_idx)
      projected = _with_auxiliary(
          self._next_input_layer(taken_outputs), take_idx)
      dense_shape = array_ops.shape(base_next_inputs)
      taken_part = array_ops.scatter_nd(
          indices=take_idx, updates=projected, shape=dense_shape)
      kept_part = array_ops.scatter_nd(
          indices=keep_idx, updates=kept_inputs, shape=dense_shape)
      return taken_part + kept_part

    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished, lambda: base_next_inputs, maybe_sample)
    return (finished, next_inputs, state)
def next_inputs(self, time, outputs, state, sample_ids, name=None):
  """next_inputs_fn for ScheduledOutputTrainingHelper.

  Mixes this step's `outputs` (where the boolean `sample_ids` is True) with
  the parent helper's next inputs (where it is False). When every entry is
  finished, the parent's inputs are returned unchanged.

  Returns:
    A `(finished, next_inputs, state)` tuple.
  """
  with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                      [time, outputs, state, sample_ids]):
    finished, base_next_inputs, state = (
        super(ScheduledOutputTrainingHelper, self).next_inputs(
            time=time,
            outputs=outputs,
            state=state,
            sample_ids=sample_ids,
            name=name))

    def maybe_sample():
      """Perform scheduled sampling."""

      def add_auxiliary(values, indices=None):
        """Concatenate outputs with auxiliary inputs, if they exist."""
        aux_tas = self._auxiliary_input_tas
        if aux_tas is None:
          return values
        read_time = time + 1
        aux_inputs = nest.map_structure(
            lambda ta: ta.read(read_time), aux_tas)
        if indices is not None:
          aux_inputs = array_ops.gather_nd(aux_inputs, indices)
        return nest.map_structure(
            lambda out, aux: array_ops.concat((out, aux), -1),
            values, aux_inputs)

      layer = self._next_input_layer
      if layer is None:
        return array_ops.where(
            sample_ids, add_auxiliary(outputs), base_next_inputs)

      use_output = math_ops.cast(array_ops.where(sample_ids), dtypes.int32)
      use_base = math_ops.cast(
          array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
      sampled_outputs = array_ops.gather_nd(outputs, use_output)
      kept = array_ops.gather_nd(base_next_inputs, use_base)
      projected = add_auxiliary(layer(sampled_outputs), use_output)
      dense_shape = array_ops.shape(base_next_inputs)
      # The two index sets come from a mask and its negation, so they are
      # disjoint and the sum simply merges the two scatters.
      return (array_ops.scatter_nd(
                  indices=use_output, updates=projected, shape=dense_shape)
              + array_ops.scatter_nd(
                  indices=use_base, updates=kept, shape=dense_shape))

    next_inputs = control_flow_ops.cond(
        math_ops.reduce_all(finished), lambda: base_next_inputs, maybe_sample)
    return (finished, next_inputs, state)
def _check_shape(value, expected_shape):
  """Check the shape of Tensor `value`, via shape inference and assertions.

  A constant-valued `expected_shape` tensor is folded into a Python list
  first, so that static shape information can be set on the result.

  Args:
    value: A Tensor, possibly with shape associated shape information.
    expected_shape: a `TensorShape`, list of `int32`, or a vector `Tensor`.

  Returns:
    new_value: A Tensor matching `value`. Accessing this tensor tests
      assertions on its shape. If expected_shape is not a `Tensor`, then
      new_value's shape has been set.

  Raises:
    ValueError: if `expected_shape` is not a `Tensor` and the shape of `value`
      is known and is not equal to `expected_shape`.
  """
  assert isinstance(value, ops.Tensor)
  if isinstance(expected_shape, tensor_shape.TensorShape):
    expected_shape = expected_shape.as_list()
  if isinstance(expected_shape, ops.Tensor):
    static_expected = tensor_util.constant_value(expected_shape)
    if static_expected is not None:
      expected_shape = [int(d) for d in static_expected]

  is_dynamic = isinstance(expected_shape, ops.Tensor)
  expected_rank = (array_ops.size(expected_shape) if is_dynamic
                   else len(expected_shape))
  value = _check_rank(value, expected_rank)

  actual_shape = array_ops.shape(value)
  shapes_match = math_ops.reduce_all(
      math_ops.equal(expected_shape, actual_shape))
  error_message = string_ops.string_join([
      "Shape of tensor %s should be: " % value.name,
      string_ops.as_string(expected_shape), ", shape received: ",
      string_ops.as_string(actual_shape)
  ])
  shape_assert = control_flow_ops.Assert(shapes_match, [error_message])
  with ops.control_dependencies([shape_assert]):
    new_value = array_ops.identity(value, name="shape_checked")

  if not is_dynamic:
    try:
      new_value.set_shape(new_value.get_shape().merge_with(expected_shape))
    except ValueError as e:
      raise ValueError("Shape check failed for %s: %s"
                       % (value.name, str(e)))
  return new_value