similarity.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:sef 作者: passalis 项目源码 文件源码
def fast_heat_similarity_matrix(X, sigma):
    """
    PyTorch based similarity calculation
    :param X: the matrix with the data
    :param sigma: scaling factor
    :return: the similarity matrix
    """
    use_gpu = False
    # Use GPU if available
    if torch.cuda.device_count() > 0:
        use_gpu = True

    X = Variable(torch.from_numpy(np.float32(X)))
    sigma = Variable(torch.from_numpy(np.float32([sigma])))
    if use_gpu:
        X, sigma = X.cuda(), sigma.cuda()

    D = sym_heat_similarity_matrix(X, sigma)

    if use_gpu:
        D = D.cpu()

    return D.data.numpy()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号