tensor_ops.py source code

python

Project: hart  Author: akosiorek
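import tensorflow as tf


def broadcast_against(tensor, against_expr):
    """Adds trailing dimensions to `tensor` (e.g. a mask) so it can be broadcast against `against_expr` (e.g. data).

    :param tensor: tensor to be broadcast
    :param against_expr: tensor that `tensor` will be broadcast against
    :return: `tensor` with trailing size-1 dims appended, so that
        tf.rank(result) == tf.rank(against_expr)
    """

    # Keep looping while `tensor` has fewer dimensions than `against_expr`.
    def cond(data, tensor):
        return tf.less(tf.rank(tensor), tf.rank(data))

    # Append one size-1 dimension per iteration; `data` passes through unchanged.
    def body(data, tensor):
        return data, tf.expand_dims(tensor, -1)

    # The rank of `tensor` changes between iterations, so its shape invariant
    # must be fully unknown; `against_expr` keeps its static shape.
    shape_invariants = [against_expr.get_shape(), tf.TensorShape(None)]
    _, tensor = tf.while_loop(cond, body, [against_expr, tensor],
                              shape_invariants=shape_invariants)
    return tensor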
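TensorFlow and NumPy broadcasting align shapes from the right, which is why the function pads on the trailing side: a (batch, time) mask cannot broadcast against (batch, time, features) data until it becomes (batch, time, 1). The minimal sketch below illustrates this with plain shape tuples. `pad_trailing` and `broadcast_shape` are hypothetical helpers written for illustration, not part of hart or TensorFlow, and they only handle static shapes, whereas the original uses `tf.while_loop` so that it also works when ranks are known only at run time.

```python
from typing import Sequence, Tuple


def pad_trailing(shape: Sequence[int], target_rank: int) -> Tuple[int, ...]:
    """Hypothetical helper: append size-1 dims until `shape` has `target_rank` dims.

    This is a static-shape counterpart of what broadcast_against does.
    """
    shape = tuple(shape)
    if len(shape) > target_rank:
        raise ValueError("shape already has more dims than target_rank")
    return shape + (1,) * (target_rank - len(shape))


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper applying NumPy-style broadcasting rules.

    Shapes are aligned from the right; each pair of dims must be equal
    or one of them must be 1.
    """
    a, b = tuple(a), tuple(b)
    rank = max(len(a), len(b))
    # Broadcasting implicitly LEFT-pads the shorter shape with 1s.
    a = (1,) * (rank - len(a)) + a
    b = (1,) * (rank - len(b)) + b
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"incompatible dims {x} and {y}")
        # A size-1 dim stretches to match the other one.
        out.append(y if x == 1 else x)
    return tuple(out)


mask_shape = (4, 7)       # e.g. (batch, time)
data_shape = (4, 7, 16)   # e.g. (batch, time, features)

try:
    # Right-aligned, this compares (1, 4, 7) with (4, 7, 16): 4 vs 7 fails.
    broadcast_shape(mask_shape, data_shape)
except ValueError as err:
    print("without padding:", err)

padded = pad_trailing(mask_shape, len(data_shape))  # (4, 7, 1)
print("with padding:", broadcast_shape(padded, data_shape))  # (4, 7, 16)
```

Padding on the right keeps the mask's existing axes lined up with the leading axes of the data, while the implicit left-padding of plain broadcasting would line them up with the wrong axes.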