def graph_preprocessor(graphs, which_set, bin_sites=None, max_dist=None,
                       random_state=1234, **params):
    """Preprocess graphs for either the 'train' or the 'test' stage.

    Parameters
    ----------
    graphs : iterable
        Input graphs. In the 'test' branch the iterable is duplicated with
        ``itertools.tee`` so it can feed two consumers.
    which_set : str
        Either ``'train'`` or ``'test'``; selects which pipeline is applied.
    bin_sites : optional
        Forwarded to ``add_distance`` (used only for 'train').
    max_dist : optional
        Forwarded to ``add_type`` (used only for 'train').
    random_state : int, optional
        Accepted for interface compatibility; not used by this function.
        NOTE(review): confirm whether it was meant to be forwarded to a
        helper such as ``split_iterator``.
    **params
        Extra keyword arguments forwarded to ``split_iterator`` in both
        branches.

    Returns
    -------
    For 'train': the result of applying ``add_distance``, then
    ``split_iterator``, then ``add_type`` to ``graphs``.
    For 'test': a tuple ``(full_graphs, graphs)``. ``full_graphs`` is
    ``transform_dictionary`` applied to one copy of the input. ``graphs`` is
    ``split_iterator`` applied to the other copy.

    Raises
    ------
    ValueError
        If ``which_set`` is neither ``'train'`` nor ``'test'``.
    """
    # Use an explicit raise rather than ``assert``. Assertions are stripped
    # under ``python -O``, and this guard must always run.
    if which_set not in ('train', 'test'):
        raise ValueError(
            "which_set must be either 'train' or 'test', got %r."
            % (which_set,))

    if which_set == 'train':
        graphs = add_distance(graphs, bin_sites)
        graphs = split_iterator(graphs, **params)
        return add_type(graphs, max_dist)

    # which_set == 'test': one copy is split, the other is transformed whole.
    # ``tee`` buffers every item read by one copy until the other copy reads
    # it too. If ``transform_dictionary`` consumes its copy eagerly, the whole
    # input will be held in memory.
    split_copy, full_copy = tee(graphs)
    full_graphs = transform_dictionary(full_copy)
    split_graphs = split_iterator(split_copy, **params)
    return full_graphs, split_graphs