def repeat(tensor: T.FloatTensor, n: int) -> T.FloatTensor:
    """
    Repeat each element of a vector n times, keeping the copies adjacent.

    The vector is lifted to a column of shape (len(tensor), 1) and tiled
    n times along the new axis. This gives a (len(tensor), n) matrix whose
    i-th row holds n copies of tensor[i]. That matrix is then passed to
    ``flatten``. If ``flatten`` is row-major (TODO confirm), then for
    [a, b] and n=2 the result is [a, a, b, b]. That is element-wise
    repetition like ``numpy.repeat``, not tiling like ``numpy.tile``.

    Args:
        tensor: A vector (i.e., 1D tensor).
        n: The number of times each element is repeated.

    Returns:
        tensor: The vector obtained by flattening the repeated matrix.

    Raises:
        ValueError: If ``tensor`` is not 1D.

    """
    # Only vectors are supported: unsqueeze(1) + repeat(1, n) assume the
    # input has exactly one axis. Use an explicit check rather than
    # `assert`, which is removed when Python runs with -O.
    rank = ndim(tensor)
    if rank != 1:
        raise ValueError(
            "repeat only supports 1D tensors, got ndim={}".format(rank))
    return flatten(tensor.unsqueeze(1).repeat(1, n))
评论列表
文章目录