def train(self):
    """Train a TreeCat model using subsample-annealed MCMC.

    Delegates the core training to ``TreeTrainer.train`` and then adds
    TreeCat-specific entries to the dict it returns. That dict is mutated
    in place and returned.

    Returns:
        A trained model as a dictionary with keys:
        config: A global config dict.
        tree: A TreeStructure instance with the learned latent
            structure.
        edge_logits: A [K]-shaped array of all edge logits.
        assignments: An [N, V]-shaped numpy array of latent cluster
            ids for each cell in the dataset, where N is the number of
            data rows and V is the number of features.
        suffstats: A dict of sufficient statistics with keys
            'ragged_index' (the ragged index of the features array),
            'vert_ss' (vertices), 'edge_ss' (edges), 'feat_ss'
            (features) and 'meas_ss'.
    """
    trained = TreeTrainer.train(self)
    trained['assignments'] = self._assignments
    # Bundle every sufficient-statistic array with the ragged index needed
    # to interpret the flattened features array.
    trained['suffstats'] = dict(
        ragged_index=self._table.ragged_index,
        vert_ss=self._vert_ss,
        edge_ss=self._edge_ss,
        feat_ss=self._feat_ss,
        meas_ss=self._meas_ss,
    )
    return trained
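# NOTE: the two lines below are web-page residue captured with this
# snippet, not code; kept as comments so they are not executed.
# 评论列表 ("comment list")
# 文章目录 ("table of contents")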