def clip_tensor(t, length):
"""Clips the input tensor along the first dimension up to the length.
Args:
t: the input tensor, assuming the rank is at least 1.
length: a tensor of shape [1] or an integer, indicating the first dimension
of the input tensor t after clipping, assuming length <= t.shape[0].
Returns:
clipped_t: the clipped tensor, whose first dimension is length. If the
length is an integer, the first dimension of clipped_t is set to length
statically.
"""
clipped_t = tf.gather(t, tf.range(length))
if not _is_tensor(length):
clipped_t = _set_dim_0(clipped_t, length)
return clipped_t
评论列表
文章目录