def set_shape(tensor, shape):
    """Fill in the missing static shape information of ``tensor``.

    Each dimension of the tensor's current static shape is merged with the
    corresponding entry of ``shape``:

    * known in both   -> the values must be equal, otherwise ``ValueError``;
    * known in one    -> the known value is kept;
    * unknown in both -> the dimension stays unknown (``None``).

    If the tensor's rank is unknown, there is nothing to merge against, so
    ``shape`` is applied as-is.

    Parameters
    ----------
    tensor : Tensor
        Tensor whose static shape is updated in place via
        ``tensor.set_shape``.
    shape : shape specification
        Normalized with ``as_tuple`` (presumably an int or a sequence --
        confirm against ``as_tuple``). Entries may be ints, ``None``
        (unknown) or ``tf.Dimension`` objects.

    Returns
    -------
    Tensor
        The same ``tensor`` object, after its static shape was updated.

    Raises
    ------
    ValueError
        If ``tensor`` is not a tensor, if its known rank differs from
        ``len(shape)``, or if a known dimension conflicts with ``shape``.
    """
    if not is_tensor(tensor):
        raise ValueError('tensor must be instance of `Tensor`.')
    shape = as_tuple(shape)
    # Unwrap tf.Dimension entries to plain ints/None so every path below
    # (including the unknown-rank one) works on the same value type.
    given = [dim.value if isinstance(dim, tf.Dimension) else dim
             for dim in shape]
    old_shape = tensor.get_shape()
    ndims = old_shape.ndims
    # ====== Unknown rank: nothing to merge, just apply ====== #
    # Without this guard, "%d" % None below raised TypeError instead of
    # setting the shape.
    if ndims is None:
        tensor.set_shape(given)
        return tensor
    # ====== Test ====== #
    if ndims != len(given):
        raise ValueError("The tensor has %d dimensions, but the given shape "
                         "has %d dimension." % (ndims, len(given)))
    # ====== DO it ====== #
    new_shape = []
    for old, new in zip(old_shape, given):
        old_value = old.value
        if old_value is None:
            # Unknown in the tensor: take whatever was given (may be None).
            new_shape.append(new)
        elif new is None or new == old_value:
            # Known in the tensor and either unspecified or consistent.
            new_shape.append(old_value)
        else:
            raise ValueError("Known shape information mismatch, from tensorflow"
                             ":%s, and given shape:%s." %
                             (str(old_shape.as_list()), str(shape)))
    tensor.set_shape(new_shape)
    return tensor
# ===========================================================================
# VALUE MANIPULATION
# ===========================================================================
评论列表
文章目录