training.py source code

python

Project: treecat Author: posterior
def __init__(self, N, V, tree_prior, config):
        """Initialize a model with an empty subsample.

        Args:
            N (int): Number of rows in the dataset.
            V (int): Number of columns (features) in the dataset.
            tree_prior: A [K]-shaped numpy array of prior edge log odds, where
                K is the number of edges in the complete graph on V vertices.
            config: A global config dict.
        """
        assert isinstance(N, int)
        assert isinstance(V, int)
        assert isinstance(tree_prior, np.ndarray)
        assert isinstance(config, dict)
        K = V * (V - 1) // 2  # Number of edges in complete graph.
        assert V <= 32768, 'Invalid # features > 32768: {}'.format(V)
        assert tree_prior.shape == (K, )
        assert tree_prior.dtype == np.float32
        self._config = config.copy()
        self._num_rows = N
        self._tree_prior = tree_prior
        self._tree = TreeStructure(V)
        assert self._tree.num_vertices == V
        self._program = make_propagation_program(self._tree.tree_grid)
        self._added_rows = set()
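
The constructor checks that `tree_prior` has one entry per edge of the complete graph on V vertices, that is K = V * (V - 1) // 2 entries. The sketch below only illustrates that count using the standard library. The helper names `num_complete_graph_edges` and `enumerate_edges` are hypothetical and not part of treecat, and the edge ordering shown is an assumption for illustration, not necessarily the order treecat uses internally.

```python
from itertools import combinations


def num_complete_graph_edges(V):
    """Return the number of undirected edges in the complete graph on V vertices."""
    return V * (V - 1) // 2


def enumerate_edges(V):
    """List every unordered vertex pair (v1, v2) with v1 < v2.

    The ordering is illustrative only (an assumption, not treecat's layout).
    """
    return list(combinations(range(V), 2))


V = 4
K = num_complete_graph_edges(V)
edges = enumerate_edges(V)

# One prior log-odds value is expected per edge, so the two counts must agree.
assert len(edges) == K
print(K, edges)
```

Under these assumptions, an array such as `np.zeros(K, dtype=np.float32)` would satisfy the shape and dtype assertions in the constructor.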