def pad_or_clip_tensor(t, length):
    """Pad or clip the input tensor along the first dimension.

    The choice between clipping and padding is made when the graph runs, via
    ``tf.cond``:
    - If ``t`` has more than ``length`` rows, it is clipped with
      ``clip_tensor``.
    - Otherwise (including when the row count equals ``length``), it is padded
      with ``pad_tensor``.

    Args:
        t: the input tensor, assuming the rank is at least 1.
        length: a tensor of shape [1] or an integer, indicating the first
            dimension of the input tensor t after processing.

    Returns:
        processed_t: the processed tensor, whose first dimension is length. If
            the length is an integer, the first dimension of the processed
            tensor is set to length statically.
    """
    # NOTE(review): tf.cond expects a scalar predicate. If `length` really is
    # a shape-[1] tensor, tf.greater yields a shape-[1] result here. Confirm
    # callers pass a scalar, or that the TF version in use accepts this.
    processed_t = tf.cond(
        tf.greater(tf.shape(t)[0], length),
        lambda: clip_tensor(t, length),
        lambda: pad_tensor(t, length),
    )
    # A plain (non-tensor) length is known when the graph is built, so record
    # it as the static first dimension of the result. This is presumably
    # needed because the tf.cond output's static first dimension is otherwise
    # undetermined.
    if not _is_tensor(length):
        processed_t = _set_dim_0(processed_t, length)
    return processed_t
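# 评论列表 ("Comment list") / 文章目录 ("Table of contents"): navigation labels
# left over from the web page this snippet was copied from; not part of the code.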