def transpose_batch_time(x):
    """Transpose the batch and time dimensions of a Tensor.

    Retains as much of the static shape information as possible.

    Args:
        x: A tensor of rank 2 or higher.

    Returns:
        x transposed along the first two dimensions.

    Raises:
        ValueError: if `x` is rank 1 or lower.
    """
    x_static_shape = x.get_shape()
    # Only reject when the static rank is known; an unknown rank (ndims is
    # None) is passed through to the transpose below.
    if x_static_shape.ndims is not None and x_static_shape.ndims < 2:
        raise ValueError(
            "Expected input tensor %s to have rank at least 2, but saw shape: %s" %
            (x, x_static_shape))

    # Permutation [1, 0, 2, ..., rank - 1], built from the dynamic rank.
    x_rank = array_ops.rank(x)
    perm = array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0)
    x_t = array_ops.transpose(x, perm)

    # Re-attach the static shape with the first two dimensions swapped.
    # TensorShape slices are used so that this neither reaches for the
    # non-public `tf.tensor_shape` attribute nor relies on `Dimension.value`,
    # which is unavailable when shape indexing returns plain ints/None.
    swapped_leading = x_static_shape[1:2].concatenate(x_static_shape[:1])
    x_t.set_shape(swapped_leading.concatenate(x_static_shape[2:]))
    return x_t
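# Comment list (page-navigation text left over from the source web page)
# Article contents (page-navigation text left over from the source web page)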