def broadcast(tensor, target_tensor):
    """Broadcast a tensor to match the shape of a target tensor.

    Every dimension of ``tensor`` whose run-time size is 1 is tiled up to the
    size of the corresponding dimension of ``target_tensor``. Every other
    dimension gets a tiling factor of 1 and is left unchanged.

    Args:
        tensor (Tensor): tensor to be tiled.
        target_tensor (Tensor): tensor whose shape is to be matched. Must have
            the same rank as ``tensor``.

    Returns:
        Tensor: ``tensor`` tiled along its unit dimensions, with the static
        shape of ``target_tensor`` attached via ``set_shape``.

    Raises:
        ValueError: if both tensors have a statically known rank and the two
            ranks differ.
    """
    # TODO: check that the tensors have no overlapping non-unity dimensions.
    tensor_rank = tensor.get_shape().ndims
    target_rank = target_tensor.get_shape().ndims
    # Compare only when both static ranks are known. Unknown ranks are left
    # to tf.tile, which rejects a multiples vector of the wrong length at
    # run time. A raise (unlike an assert) survives `python -O`.
    if (tensor_rank is not None and target_rank is not None
            and tensor_rank != target_rank):
        raise ValueError(
            "broadcast: rank mismatch between tensor (%d) and target (%d)"
            % (tensor_rank, target_rank))

    orig_shape = tf.shape(tensor)
    target_shape = tf.shape(target_tensor)

    # tf.select was renamed to tf.where in TensorFlow 1.0. Before 1.0,
    # tf.where only took a single argument, so prefer tf.select whenever it
    # exists and fall back to tf.where otherwise.
    select = getattr(tf, "select", tf.where)

    # Per dimension: if dim == 1, tile it to the target size; else keep it (x1).
    # ones_like() works even when the static rank is unknown, unlike
    # tf.ones([rank]), and matches orig_shape's int32 dtype.
    tiling_factor = select(tf.equal(orig_shape, 1),
                           target_shape,
                           tf.ones_like(orig_shape))
    broadcasted = tf.tile(tensor, tiling_factor)

    # tf.tile with a dynamic multiples tensor loses static shape information;
    # restore it from the target.
    broadcasted.set_shape(target_tensor.get_shape())
    return broadcasted
评论列表
文章目录