training.py source file

python

Project: treecat  Author: posterior
```python
def train(self):
    """Train a TreeCat model using subsample-annealed MCMC.

    Returns:
        A trained model as a dictionary with keys:
            config: A global config dict.
            tree: A TreeStructure instance with the learned latent
                structure.
            edge_logits: A [K]-shaped array of all edge logits.
            suffstats: Sufficient statistics of features, vertices,
                edges, and measurements, plus a ragged_index for the
                features array.
            assignments: An [N, V]-shaped numpy array of latent cluster
                ids for each cell in the dataset, where N is the number of
                data rows and V is the number of features.
    """
    # Run TreeTrainer.train on this trainer, then attach the per-cell
    # cluster assignments and the sufficient statistics to the result.
    model = TreeTrainer.train(self)
    model['assignments'] = self._assignments
    model['suffstats'] = {
        'ragged_index': self._table.ragged_index,
        'vert_ss': self._vert_ss,
        'edge_ss': self._edge_ss,
        'feat_ss': self._feat_ss,
        'meas_ss': self._meas_ss,
    }
    return model
```
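You can sanity-check the returned dictionary before using it. The sketch below is hypothetical. `check_model` and `feature_slice` are not part of treecat. The sketch assumes that `ragged_index` holds V + 1 monotone offsets into the first axis of `feat_ss`, so feature v owns rows `ragged_index[v]` up to `ragged_index[v + 1]`. The toy model uses made-up sizes and fills in only the keys the checks read.

```python
import numpy as np


def feature_slice(ragged_index, v):
    """Return the slice of feat_ss rows belonging to feature v.

    Hypothetical helper; assumes ragged_index stores V + 1 offsets.
    """
    return slice(int(ragged_index[v]), int(ragged_index[v + 1]))


def check_model(model):
    """Sanity-check a dict shaped like the one train() documents.

    Hypothetical helper. Returns (N, V): number of rows and features.
    """
    assignments = np.asarray(model['assignments'])
    if assignments.ndim != 2:
        raise ValueError('assignments must be an [N, V]-shaped array')
    num_rows, num_features = assignments.shape
    ss = model['suffstats']
    ragged_index = np.asarray(ss['ragged_index'])
    # Assumed layout: one start offset per feature plus a final end offset.
    if ragged_index.shape != (num_features + 1,):
        raise ValueError('ragged_index must have V + 1 entries')
    # Assumed layout: the last offset equals the number of feat_ss rows.
    if ragged_index[-1] != len(ss['feat_ss']):
        raise ValueError('ragged_index must end at len(feat_ss)')
    return num_rows, num_features


# Toy model with made-up sizes: N=4 rows, V=2 features, M=3 clusters.
M = 3
model = {
    'assignments': np.zeros((4, 2), dtype=np.int8),
    'suffstats': {
        # Feature 0 owns feat_ss rows 0-1; feature 1 owns rows 2-4.
        'ragged_index': np.array([0, 2, 5]),
        'feat_ss': np.zeros((5, M), dtype=np.int32),
    },
}
print(check_model(model))  # (4, 2)
print(feature_slice(model['suffstats']['ragged_index'], 1))  # slice(2, 5, None)
```

Storing per-feature statistics in one flat array indexed by offsets avoids a Python list of differently sized arrays. This is the usual motivation for a ragged index, although the exact treecat layout should be confirmed against its source.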