def __init__(self, N, V, tree_prior, config):
    """Set up a model whose subsample starts out empty.

    Args:
        N (int): Row count of the dataset.
        V (int): Column (feature) count of the dataset; at most 32768.
        tree_prior: A float32 numpy array of shape [K] holding prior edge
            log odds, where K = V * (V - 1) // 2 is the number of edges
            in the complete graph on V vertices. Stored by reference,
            not copied.
        config: A global config dict; a shallow copy of it is kept.

    Raises:
        AssertionError: If an argument has the wrong type, V exceeds
            32768, tree_prior has the wrong shape or dtype, or the
            constructed tree reports a vertex count other than V.
    """
    # Type-check the arguments in signature order.
    expected_types = (
        (N, int),
        (V, int),
        (tree_prior, np.ndarray),
        (config, dict),
    )
    for arg, expected in expected_types:
        assert isinstance(arg, expected)

    assert V <= 32768, 'Invalid # features > 32768: {}'.format(V)
    num_edges = V * (V - 1) // 2  # Edge count of the complete graph.
    assert tree_prior.shape == (num_edges,)
    assert tree_prior.dtype == np.float32

    self._config = config.copy()
    self._num_rows = N
    self._tree_prior = tree_prior

    self._tree = TreeStructure(V)
    assert self._tree.num_vertices == V
    self._program = make_propagation_program(self._tree.tree_grid)
    self._added_rows = set()  # No rows have been added yet.
评论列表
文章目录