utils.py source code


Project: cnn-graph-classification   Author: giannisnik
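
```python
import networkx as nx
import numpy as np


def load_data(ds_name, use_node_labels):
    """Load ../datasets/<ds_name>/<ds_name>_*.txt into networkx graphs.

    Returns a list of nx.Graph objects and a float array of graph labels.
    """
    prefix = "../datasets/%s/%s" % (ds_name, ds_name)
    node2graph = {}
    Gs = []

    # Line c (1-based) holds the id of the graph that node c belongs to.
    # Nodes are assumed to be sorted by graph id, with ids starting at 1.
    with open(prefix + "_graph_indicator.txt", "r") as f:
        for c, line in enumerate(f, start=1):
            node2graph[c] = int(line)  # int() ignores the trailing newline
            if node2graph[c] != len(Gs):
                Gs.append(nx.Graph())
            Gs[-1].add_node(c)

    # Each line is an edge "u, v" between global node ids.
    with open(prefix + "_A.txt", "r") as f:
        for line in f:
            u, v = (int(x) for x in line.split(","))
            Gs[node2graph[u] - 1].add_edge(u, v)

    if use_node_labels:
        # Line c holds the label of node c.
        with open(prefix + "_node_labels.txt", "r") as f:
            for c, line in enumerate(f, start=1):
                # G.node was removed in networkx 2.4; G.nodes works in 2.x+.
                Gs[node2graph[c] - 1].nodes[c]["label"] = int(line)

    # Line j holds the class label of graph j.
    labels = []
    with open(prefix + "_graph_labels.txt", "r") as f:
        for line in f:
            labels.append(int(line))

    # np.float was removed in NumPy 1.24; the builtin float maps to float64.
    labels = np.array(labels, dtype=float)
    return Gs, labels
```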
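
The file layout `load_data` reads appears to follow the TU Dortmund graph-kernel benchmark format: `*_graph_indicator.txt` maps each node to a graph, `*_A.txt` lists edges as "u, v" over global node ids, and `*_node_labels.txt` / `*_graph_labels.txt` hold one integer per line. The sketch below parses the same files with only the standard library, so the format can be checked without networkx. It is a hedged illustration, not code from the cnn-graph-classification repository: `load_tu_graphs`, `write_demo_dataset`, the dict-based graph representation and the toy "DEMO" dataset are all hypothetical, and it assumes graph ids are the integers 1..N.

```python
import os
import tempfile


def load_tu_graphs(root, ds_name, use_node_labels=False):
    """Hypothetical dependency-free loader for the same TU-style files.

    Returns (graphs, labels). Each graph is a dict with
      'nodes': {global_node_id: node_label or None}
      'edges': set of frozenset({u, v}), so "u, v" and "v, u" collapse.
    Assumes graph ids in *_graph_indicator.txt are the integers 1..N.
    """
    def read_ints(suffix):
        path = os.path.join(root, ds_name, "%s_%s.txt" % (ds_name, suffix))
        with open(path) as f:
            # int() tolerates surrounding whitespace and the newline.
            return [int(line) for line in f if line.strip()]

    # Line i (1-based) of *_graph_indicator.txt is the graph id of node i.
    node2graph = dict(enumerate(read_ints("graph_indicator"), start=1))
    node_labels = {}
    if use_node_labels:
        # Line i of *_node_labels.txt is the label of node i.
        node_labels = dict(enumerate(read_ints("node_labels"), start=1))

    graphs = [{"nodes": {}, "edges": set()} for _ in range(max(node2graph.values()))]
    for node, g in node2graph.items():
        graphs[g - 1]["nodes"][node] = node_labels.get(node)

    # Each line of *_A.txt is "u, v": an edge between global node ids u and v.
    path = os.path.join(root, ds_name, "%s_A.txt" % ds_name)
    with open(path) as f:
        for line in f:
            if line.strip():
                u, v = (int(x) for x in line.split(","))
                graphs[node2graph[u] - 1]["edges"].add(frozenset((u, v)))

    # Line j of *_graph_labels.txt is the class label of graph j.
    labels = read_ints("graph_labels")
    return graphs, labels


def write_demo_dataset(root, ds_name):
    """Hypothetical toy data: graph 1 is the path 1-2-3, graph 2 is the edge 4-5."""
    files = {
        "graph_indicator": "1\n1\n1\n2\n2\n",
        "A": "1, 2\n2, 1\n2, 3\n3, 2\n4, 5\n5, 4\n",
        "node_labels": "0\n1\n0\n2\n2\n",
        "graph_labels": "1\n-1\n",
    }
    os.makedirs(os.path.join(root, ds_name))
    for suffix, text in files.items():
        with open(os.path.join(root, ds_name, "%s_%s.txt" % (ds_name, suffix)), "w") as f:
            f.write(text)


with tempfile.TemporaryDirectory() as root:
    write_demo_dataset(root, "DEMO")
    graphs, labels = load_tu_graphs(root, "DEMO", use_node_labels=True)

print(len(graphs), [len(g["edges"]) for g in graphs], labels)
```

Storing each undirected edge as a `frozenset` mirrors how `nx.Graph` treats the duplicated "u, v" / "v, u" lines in `*_A.txt`: both directions collapse into a single edge.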