utils.py source code

python

Project: lang2program  Author: kelvinguu
import tensorflow as tf


def expand_dims_for_broadcast(low_tensor, high_tensor):
    """Expand the dimensions of a lower-rank tensor, so that its rank matches that of a higher-rank tensor.

    This makes it possible to perform broadcast operations between low_tensor and high_tensor.

    Args:
        low_tensor (Tensor): lower-rank Tensor with shape [s_0, ..., s_p]
        high_tensor (Tensor): higher-rank Tensor with shape [s_0, ..., s_p, ..., s_n]

    Note that the shape of low_tensor must be a prefix of the shape of high_tensor.

    Returns:
        Tensor: the lower-rank tensor, but with shape expanded to be [s_0, ..., s_p, 1, 1, ..., 1]
    """
    orig_shape = tf.shape(low_tensor)
    orig_rank = tf.rank(low_tensor)
    target_rank = tf.rank(high_tensor)

    # assert that shapes are compatible; assert_broadcastable is a helper
    # defined elsewhere in the project and not shown in this excerpt
    assert_op = assert_broadcastable(low_tensor, high_tensor)

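    with tf.control_dependencies([assert_op]):
        # one trailing 1 for each dimension that low_tensor is missing
        pad_shape = tf.tile([1], [target_rank - orig_rank])
        # TF >= 1.0 signature: tf.concat(values, axis)
        new_shape = tf.concat([orig_shape, pad_shape], axis=0)
        result = tf.reshape(low_tensor, new_shape)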

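    # add static shape information (only possible when both static ranks are known)
    high_shape_static = high_tensor.get_shape()
    low_shape_static = low_tensor.get_shape()
    if high_shape_static.ndims is not None and low_shape_static.ndims is not None:
        extra_rank = high_shape_static.ndims - low_shape_static.ndims
        # as_list() yields ints (or None for unknown dims); tf.TensorShape accepts both
        result_dims = low_shape_static.as_list() + [1] * extra_rank
        result.set_shape(tf.TensorShape(result_dims))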

    return result
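The shape bookkeeping in `expand_dims_for_broadcast` can be checked without TensorFlow. Below is a minimal pure-Python sketch that works on plain shape lists rather than tensors; `broadcast_target_shape` and `static_result_shape` are hypothetical helper names introduced only for illustration. The prefix check follows the docstring's requirement that the shape of `low_tensor` be a prefix of the shape of `high_tensor`; it is not the project's `assert_broadcastable`, whose exact rules are not shown here. `None` stands for an unknown dimension or rank, as in TensorFlow static shapes.

```python
from typing import List, Optional, Sequence

# Hypothetical helpers for illustration; not part of lang2program.


def broadcast_target_shape(low_shape: Sequence[int], high_shape: Sequence[int]) -> List[int]:
    """Pad low_shape with trailing 1s until it has len(high_shape) dimensions.

    Mirrors the dynamic path: tf.tile([1], [target_rank - orig_rank]) builds the
    padding and tf.concat appends it to the original shape.
    """
    p = len(low_shape)
    # Prefix requirement taken from the docstring (assumed stand-in for assert_broadcastable)
    if p > len(high_shape) or list(high_shape[:p]) != list(low_shape):
        raise ValueError(f"{list(low_shape)} is not a prefix of {list(high_shape)}")
    return list(low_shape) + [1] * (len(high_shape) - p)


def static_result_shape(
    low_dims: Optional[Sequence[Optional[int]]], high_ndims: Optional[int]
) -> Optional[List[Optional[int]]]:
    """Static-shape counterpart: if either rank is unknown (None), nothing can be set."""
    if low_dims is None or high_ndims is None:
        return None
    return list(low_dims) + [1] * (high_ndims - len(low_dims))


if __name__ == "__main__":
    print(broadcast_target_shape([2, 3], [2, 3, 4, 5]))  # [2, 3, 1, 1]
    print(static_result_shape([None, 3], 4))             # [None, 3, 1, 1]
```

The trailing 1s are what make the broadcast line up: TensorFlow, like NumPy, aligns dimensions from the right, so a `[2, 3]` tensor would otherwise be matched against the last two axes `[4, 5]` of a `[2, 3, 4, 5]` tensor. Reshaping it to `[2, 3, 1, 1]` aligns it with the leading axes instead.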