def _transpose_batch_time(x):
    """Transpose the batch and time dimensions of a Tensor.

    Swaps the first two dimensions of `x` and leaves every trailing
    dimension in its original order. The permutation is built from the
    dynamic rank of `x`, and whatever static shape information `x` carries
    is re-applied to the result with its first two entries swapped.

    Args:
        x: A tensor of rank 2 or higher.

    Returns:
        `x` transposed along the first two dimensions.

    Raises:
        ValueError: If the static rank of `x` is known and is lower than 2.
    """
    x_static_shape = x.get_shape()
    static_rank = x_static_shape.ndims
    if static_rank is not None and static_rank < 2:
        raise ValueError(
            "Expected input tensor %s to have rank at least 2, "
            "but saw shape: %s" % (x, x_static_shape))

    # The permutation is assembled at graph/run time so that tensors whose
    # rank is not statically known can still be transposed.
    x_rank = tf.rank(x)
    perm = tf.concat(([1, 0], tf.range(2, x_rank)), axis=0)
    x_t = tf.transpose(x, perm)

    # Re-attach the static shape with the first two dimensions swapped.
    # `as_list()` yields plain ints / None, so this does not depend on
    # `Dimension.value` or on adding a list to a `TensorShape`. It rejects
    # shapes of unknown rank, and such a shape has nothing to restore anyway,
    # so the step is skipped in that case.
    if static_rank is not None:
        dims = x_static_shape.as_list()
        x_t.set_shape([dims[1], dims[0]] + dims[2:])
    return x_t
评论列表
文章目录