def prepare_pairs(X, y, site, indices):
    """Prepare the graph pairs before feeding them to the network.

    Builds every unordered pair (i, j) with i < j of the samples selected by
    ``indices`` and stacks the two members of each pair along a new last axis.

    Args:
        X: array whose first axis indexes samples, e.g. shape (N, M, F); any
            trailing shape is accepted.
        y: per-sample label array, indexable by ``indices``.
        site: per-sample site array, indexable by ``indices``.
        indices: selector for the samples to pair (integer indices or a
            boolean mask over the first axis).

    Returns:
        ``(X_pairs, y_pairs, site_pairs)``. With ``k`` selected samples there
        are ``P = k * (k - 1) // 2`` pairs, ordered as
        ``np.triu_indices(k, 1)``.

        - ``X_pairs`` is a float64 array of shape ``(P, *X.shape[1:], 2)``.
          Slot ``0`` of the last axis holds the first member and slot ``1``
          holds the second.
        - ``y_pairs`` is a float64 array of length ``P``: 1.0 where both
          members have equal labels, 0.0 where they differ.
        - ``site_pairs`` is built the same way from ``site``.
    """
    # Select once: each fancy-indexing operation materialises a copy.
    X_sel = X[indices]
    y_sel = y[indices]
    site_sel = site[indices]

    # Count the rows actually selected, so boolean masks work as well as
    # integer index lists.
    k = X_sel.shape[0]
    first, second = np.triu_indices(k, 1)
    n_pairs = k * (k - 1) // 2

    # np.empty defaults to float64, matching the original np.ones allocation.
    # Every element is overwritten below.
    X_pairs = np.empty((n_pairs,) + X_sel.shape[1:] + (2,))
    X_pairs[..., 0] = X_sel[first]
    X_pairs[..., 1] = X_sel[second]

    # 0.0 where the two members differ, 1.0 otherwise.
    y_pairs = np.where(y_sel[first] != y_sel[second], 0.0, 1.0)
    site_pairs = np.where(site_sel[first] != site_sel[second], 0.0, 1.0)
    return X_pairs, y_pairs, site_pairs
评论列表
文章目录