graph.py source code

python

Project: ProtScan  Author: gianlucacorrado
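from itertools import tee

# add_distance, split_iterator, add_type and transform_dictionary are
# helpers defined elsewhere in the ProtScan project (not shown here).


def graph_preprocessor(graphs, which_set, bin_sites=None, max_dist=None,
                       random_state=1234, **params):
    """Preprocess graphs for training ('train') or prediction ('test').

    random_state is accepted but not used in this function.
    """
    if which_set not in ('train', 'test'):
        # Raise instead of assert: assert statements are skipped under
        # `python -O`, which would let a bad value through silently.
        raise ValueError("which_set must be either 'train' or 'test', "
                         "got %r" % (which_set,))

    if which_set == 'train':
        # Add distance information (from bin_sites), split each graph
        # into subgraphs, then label the subgraphs using max_dist.
        graphs = add_distance(graphs, bin_sites)
        graphs = split_iterator(graphs, **params)
        graphs = add_type(graphs, max_dist)
        return graphs

    # which_set == 'test': graphs may be a one-shot iterator, so tee()
    # duplicates it. One copy is passed to transform_dictionary to build
    # full_graphs; the other is split into subgraphs.
    graphs, graphs_ = tee(graphs)
    full_graphs = transform_dictionary(graphs_)
    graphs = split_iterator(graphs, **params)
    return full_graphs, graphs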
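The 'test' branch depends on itertools.tee so that a single stream of graphs can be consumed twice: once whole, once split into pieces. The sketch below isolates that pattern. It is not ProtScan's code: the graphs are plain dicts, and split_iterator, transform_dictionary and preprocess_test are hypothetical stand-ins written only to make the example runnable.

```python
from itertools import tee


# Hypothetical stand-in: split each graph's node list into chunks of
# `size` nodes, lazily, one chunk at a time.
def split_iterator(graphs, size=2):
    for g in graphs:
        nodes = g["nodes"]
        for i in range(0, len(nodes), size):
            yield {**g, "nodes": nodes[i:i + size]}


# Hypothetical stand-in: collect the full graphs into a dict keyed by id.
def transform_dictionary(graphs):
    return {g["id"]: g for g in graphs}


def preprocess_test(graphs, **params):
    # tee() returns two independent iterators over the same input, so
    # the stream can feed both consumers even if it is a one-shot iterator.
    graphs, graphs_ = tee(graphs)
    full_graphs = transform_dictionary(graphs_)
    subgraphs = split_iterator(graphs, **params)
    return full_graphs, subgraphs


stream = iter([{"id": "a", "nodes": [1, 2, 3]}, {"id": "b", "nodes": [4]}])
full, parts = preprocess_test(stream, size=2)
print(sorted(full))                 # ['a', 'b']
print([p["nodes"] for p in parts])  # [[1, 2], [3], [4]]
```

Because transform_dictionary exhausts one copy before the other is read, tee has to buffer every item for the second iterator, so memory use grows with the size of the whole input. In that situation building a list once would cost about the same; tee mainly keeps the function's interface iterator-based.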