def broadcast_against(tensor, against_expr):
    """Append trailing size-1 dimensions to ``tensor`` so it broadcasts against ``against_expr``.

    Dimensions are added with ``tf.expand_dims(..., -1)`` until
    ``rank(tensor) == rank(against_expr)``. If ``tensor`` already has a rank
    greater than or equal to that of ``against_expr`` it is returned
    unchanged (no dimensions are ever removed).

    When both ranks are known statically, the expansion happens at graph
    construction time, so the result keeps a fully inferred static shape.
    Otherwise a ``tf.while_loop`` performs the expansion at run time and the
    result's static shape is unknown.

    :param tensor: tensor to be broadcast (for example a mask).
    :param against_expr: tensor whose rank ``tensor`` is raised to (for
        example the data the mask applies to).
    :return: ``tensor`` with trailing singleton dimensions appended so that
        ``tf.rank(result) == tf.rank(against_expr)`` whenever
        ``tf.rank(tensor) <= tf.rank(against_expr)``.
    """
    # while_loop converted its loop vars implicitly; do it explicitly so
    # that .get_shape() below also works for non-Tensor inputs.
    tensor = tf.convert_to_tensor(tensor)
    tensor_rank = tensor.get_shape().ndims
    target_rank = against_expr.get_shape().ndims

    if tensor_rank is not None and target_rank is not None:
        # Static path: no graph loop is needed and shape inference is kept.
        # range() of a non-positive count is empty, so a tensor that is
        # already of equal or higher rank is returned unchanged.
        for _ in range(target_rank - tensor_rank):
            tensor = tf.expand_dims(tensor, -1)
        return tensor

    # Dynamic fallback: at least one rank is only known at run time.
    def cond(data, t):
        return tf.less(tf.rank(t), tf.rank(data))

    def body(data, t):
        return data, tf.expand_dims(t, -1)

    # `against_expr` is carried through unchanged, so it keeps its own shape.
    # The rank of `t` changes every iteration, so its invariant must be
    # fully unknown.
    shape_invariants = [against_expr.get_shape(), tf.TensorShape(None)]
    _, tensor = tf.while_loop(cond, body, [against_expr, tensor], shape_invariants)
    return tensor
评论列表
文章目录