main_all_sites.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:gcn_metric_learning 作者: sk1712 项目源码 文件源码
def prepare_pairs(X, y, site, indices):
    """ Prepare the graph pairs before feeding them to the network """
    N, M, F = X.shape
    n_pairs = int(len(indices) * (len(indices) - 1) / 2)
    triu_pairs = np.triu_indices(len(indices), 1)

    X_pairs = np.ones((n_pairs, M, F, 2))
    X_pairs[:, :, :, 0] = X[indices][triu_pairs[0]]
    X_pairs[:, :, :, 1] = X[indices][triu_pairs[1]]

    site_pairs = np.ones(int(n_pairs))
    site_pairs[site[indices][triu_pairs[0]] != site[indices][triu_pairs[1]]] = 0

    y_pairs = np.ones(int(n_pairs))
    y_pairs[y[indices][triu_pairs[0]] != y[indices][triu_pairs[1]]] = 0  # -1

    print(n_pairs)

    return X_pairs, y_pairs, site_pairs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号