def _read_stripped_lines(path):
    """Yield each non-blank line of the text file at *path*, stripped of surrounding whitespace."""
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. a trailing empty line) instead of crashing on int("").
            if line:
                yield line


def load_data(ds_name, use_node_labels, root="../datasets"):
    """Load a graph-classification dataset stored as plain-text files.

    Files are read from ``<root>/<ds_name>/`` (the TU Dortmund benchmark layout):

    * ``<ds_name>_graph_indicator.txt`` -- line *i* holds the 1-based id of the
      graph that node *i* belongs to (nodes are numbered 1, 2, ... by line).
    * ``<ds_name>_A.txt`` -- one comma-separated edge ``u, v`` per line, using
      the node numbers above; the edge is added to the graph owning ``u``.
    * ``<ds_name>_node_labels.txt`` -- line *i* holds the integer label of
      node *i*; only read when ``use_node_labels`` is true.
    * ``<ds_name>_graph_labels.txt`` -- line *j* holds the integer label of
      graph *j*.

    Args:
        ds_name: Dataset name, used both as the directory name and file prefix.
        use_node_labels: If true, store each node's label in its ``'label'``
            node attribute.
        root: Directory containing the dataset folders (default
            ``"../datasets"``, matching the previous hard-coded location).

    Returns:
        A tuple ``(Gs, labels)``. ``Gs`` is a list of ``networkx.Graph`` in
        which ``Gs[j - 1]`` is graph *j*. ``labels`` is a float NumPy array
        holding the graph labels in file order.

    Raises:
        ValueError: if a graph id in the indicator file is smaller than 1, or
            if any line cannot be parsed as an integer.
        OSError: if a required file cannot be opened.
    """
    import os

    def path(suffix):
        # Build "<root>/<ds_name>/<ds_name>_<suffix>.txt".
        return os.path.join(root, ds_name, "%s_%s.txt" % (ds_name, suffix))

    node2graph = {}
    Gs = []
    for node, line in enumerate(_read_stripped_lines(path("graph_indicator")), start=1):
        graph_id = int(line)
        if graph_id < 1:
            raise ValueError("graph id must be >= 1, got %d for node %d" % (graph_id, node))
        node2graph[node] = graph_id
        # Grow the list so that Gs[graph_id - 1] exists; this also tolerates
        # skipped or non-ascending graph ids instead of misplacing nodes.
        while len(Gs) < graph_id:
            Gs.append(nx.Graph())
        Gs[graph_id - 1].add_node(node)

    for line in _read_stripped_lines(path("A")):
        parts = line.split(",")
        # int() ignores surrounding whitespace, so "1, 2" parses cleanly.
        u, v = int(parts[0]), int(parts[1])
        Gs[node2graph[u] - 1].add_edge(u, v)

    if use_node_labels:
        for node, line in enumerate(_read_stripped_lines(path("node_labels")), start=1):
            # add_node on an existing node merges attributes. Unlike Graph.node
            # (removed in NetworkX 2.4), this works across NetworkX versions.
            Gs[node2graph[node] - 1].add_node(node, label=int(line))

    # dtype=float is float64, which is what the removed np.float alias meant.
    labels = np.array(
        [int(line) for line in _read_stripped_lines(path("graph_labels"))],
        dtype=float,
    )
    return Gs, labels
评论列表
文章目录