matrix.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:paysage 作者: drckf 项目源码 文件源码
def pdist(x: T.FloatTensor, y: T.FloatTensor) -> T.FloatTensor:
    """
    Compute the pairwise distance matrix between the rows of x and y.

    Args:
        x (tensor (num_samples_1, num_units))
        y (tensor (num_samples_2, num_units))

    Returns:
        tensor (num_samples_1, num_samples_2)

    """
    inner = dot(x, transpose(y))
    x_mag = norm(x, axis=1) ** 2
    y_mag = norm(y, axis=1) ** 2
    squared = add(unsqueeze(y_mag, axis=0), add(unsqueeze(x_mag, axis=1), -2*inner))
    return torch.sqrt(clip(squared, a_min=0))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号