matrix.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:paysage 作者: drckf 项目源码 文件源码
def repeat(tensor: T.FloatTensor, n: int) -> T.FloatTensor:
    """
    Repeat tensor n times along specified axis.

    Args:
        tensor: A vector (i.e., 1D tensor).
        n: The number of repeats.

    Returns:
        tensor: A vector created from many repeats of the input tensor.

    """
    # current implementation only works for vectors
    assert ndim(tensor) == 1
    return flatten(tensor.unsqueeze(1).repeat(1, n))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号