utils.py source code

python

Project: lang2program  Author: kelvinguu
import tensorflow as tf


def broadcast(tensor, target_tensor):
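    """Broadcast a tensor to match the shape of a target tensor.

    Every size-1 dimension of `tensor` is tiled up to the size of the
    corresponding dimension of `target_tensor`. Both tensors must have
    the same rank.

    Args:
        tensor (Tensor): tensor to be tiled
        target_tensor (Tensor): tensor whose shape is to be matched

    Returns:
        Tensor: the tiled tensor, with the static shape of `target_tensor`
    """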
    rank = lambda t: t.get_shape().ndims
    # TODO: assert that tensors have no overlapping non-unity dimensions
    assert rank(tensor) == rank(target_tensor)

    orig_shape = tf.shape(tensor)
    target_shape = tf.shape(target_tensor)

    # Per-dimension tiling factor:
    #   if the original dim == 1, repeat it target_dim times
    #   else, repeat it once (leave it unchanged)
    tiling_factor = tf.where(
        tf.equal(orig_shape, 1),
        target_shape,
        tf.ones([rank(tensor)], dtype=tf.int32))
    broadcasted = tf.tile(tensor, tiling_factor)

    # Add static shape information
    broadcasted.set_shape(target_tensor.get_shape())

    return broadcasted
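
The tiling-factor rule used in broadcast can be checked without TensorFlow. The following is a minimal pure-Python sketch of that rule on plain shape lists. The helper names `tiling_factors` and `tiled_shape` are hypothetical and not part of lang2program, and the explicit error for mismatched non-unit dimensions is an added assumption (the original function performs no such check).

```python
def tiling_factors(orig_shape, target_shape):
    """Return per-dimension repeat counts that tile orig_shape up to target_shape.

    Hypothetical helper for illustration; it mirrors the tf.where rule in broadcast.
    """
    if len(orig_shape) != len(target_shape):
        raise ValueError("shapes must have the same rank")
    factors = []
    for orig_dim, target_dim in zip(orig_shape, target_shape):
        if orig_dim == 1:
            # A size-1 dimension is repeated until it reaches the target size.
            factors.append(target_dim)
        elif orig_dim == target_dim:
            # A dimension that already matches is left alone.
            factors.append(1)
        else:
            # Assumption: treat any other mismatch as an error.
            raise ValueError(
                "cannot broadcast dimension %d to %d" % (orig_dim, target_dim))
    return factors


def tiled_shape(orig_shape, factors):
    """Shape that results from repeating each dimension by its factor."""
    return [dim * factor for dim, factor in zip(orig_shape, factors)]


factors = tiling_factors([3, 1, 5], [3, 4, 5])
print(factors)                          # [1, 4, 1]
print(tiled_shape([3, 1, 5], factors))  # [3, 4, 5]
```

Only the middle dimension is tiled here, which is the behaviour broadcast relies on when it passes the factors to tf.tile.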