utils.py source file

python

Project: polyaxon  Author: polyaxon
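from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def transpose_batch_time(x):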
    """Transpose the batch and time dimensions of a Tensor.

    Retains as much of the static shape information as possible.

    Args:
        x: A tensor of rank 2 or higher.

    Returns:
        x transposed along the first two dimensions.

    Raises:
        ValueError: if `x` is rank 1 or lower.
    """
    x_static_shape = x.get_shape()
    if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
        raise ValueError(
            "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
            (x, x_static_shape))
    x_rank = array_ops.rank(x)
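    # Swap axes 0 and 1; axes 2..rank-1 keep their original order.
    x_t = array_ops.transpose(
        x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0))
    # Re-attach the static shape, which transpose may not infer from a
    # dynamically built permutation. `.value` assumes TF 1.x Dimension objects.
    x_t.set_shape(tensor_shape.TensorShape([
        x_static_shape[1].value, x_static_shape[0].value]).concatenate(x_static_shape[2:]))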
    return x_t
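
The two pieces of bookkeeping in the function above, the axis permutation and the swapped static shape, can be illustrated without TensorFlow. The following is only a minimal sketch in plain Python: `batch_time_perm` and `transposed_static_shape` are hypothetical helper names that do not appear in the original file, and `None` is assumed to stand for an unknown dimension.

```python
def batch_time_perm(rank):
    """Return the permutation that swaps axes 0 and 1 (hypothetical helper)."""
    if rank < 2:
        raise ValueError("Expected rank at least 2, but saw rank: %d" % rank)
    # Same layout as concat(([1, 0], range(2, rank))) in the function above.
    return [1, 0] + list(range(2, rank))


def transposed_static_shape(shape):
    """Return the static shape after the swap; None marks an unknown dimension."""
    if len(shape) < 2:
        raise ValueError("Expected rank at least 2, but saw shape: %s" % (shape,))
    # The first two entries trade places; trailing entries are untouched.
    return [shape[1], shape[0]] + list(shape[2:])


print(batch_time_perm(4))                         # [1, 0, 2, 3]
print(transposed_static_shape([32, None, 128]))   # [None, 32, 128]
```

With a batch-major shape such as `[32, None, 128]` (batch, time, features), the result is the time-major shape `[None, 32, 128]`. The unknown time dimension is carried through as unknown.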