import tensorflow as tf


def expand_dims_for_broadcast(low_tensor, high_tensor):
    """Expand the dimensions of a lower-rank tensor so that its rank matches that of a higher-rank tensor.

    This makes it possible to perform broadcast operations between low_tensor and high_tensor.

    Args:
        low_tensor (Tensor): lower-rank Tensor with shape [s_0, ..., s_p]
        high_tensor (Tensor): higher-rank Tensor with shape [s_0, ..., s_p, ..., s_n]

    Note that the shape of low_tensor must be a prefix of the shape of high_tensor.

    Returns:
        Tensor: the lower-rank tensor, but with shape expanded to be [s_0, ..., s_p, 1, 1, ..., 1]
    """
    orig_shape = tf.shape(low_tensor)
    orig_rank = tf.rank(low_tensor)
    target_rank = tf.rank(high_tensor)

    # assert that the shape of low_tensor is a prefix of the shape of high_tensor
    assert_op = assert_broadcastable(low_tensor, high_tensor)

    with tf.control_dependencies([assert_op]):
        # pad the dynamic shape with trailing 1s until it reaches the target rank
        pad_shape = tf.tile([1], [target_rank - orig_rank])
        new_shape = tf.concat([orig_shape, pad_shape], axis=0)
        result = tf.reshape(low_tensor, new_shape)

    # add static shape information
    high_shape_static = high_tensor.get_shape()
    low_shape_static = low_tensor.get_shape()
    extra_rank = high_shape_static.ndims - low_shape_static.ndims
    result_dims = list(low_shape_static.dims) + [tf.Dimension(1)] * extra_rank
    result_shape = tf.TensorShape(result_dims)
    result.set_shape(result_shape)

    return result
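
A minimal usage sketch, assuming a TensorFlow 1.x graph-mode session and that assert_broadcastable is defined elsewhere in the same module (it is referenced above but not shown here). The variable names weights and scores are hypothetical: per-example weights of shape [batch] are expanded to [batch, 1] so that they broadcast against scores of shape [batch, num_classes].

# hypothetical usage example (TensorFlow 1.x)
weights = tf.constant([2.0, 3.0])                       # shape [2]
scores = tf.ones([2, 3])                                # shape [2, 3]
expanded = expand_dims_for_broadcast(weights, scores)   # shape [2, 1]
weighted = expanded * scores                            # broadcasts to shape [2, 3]

with tf.Session() as sess:
    print(sess.run(weighted))                           # [[2. 2. 2.] [3. 3. 3.]]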