training.py 文件源码

python
阅读 42 收藏 0 点赞 0 评论 0

项目:treecat 作者: posterior 项目源码 文件源码
def treegauss_remove_row(
        data_row,
        tree_grid,
        latent_row,
        vert_ss,
        edge_ss,
        feat_ss, ):
    # Update sufficient statistics.
    for v in range(latent_row.shape[0]):
        z = latent_row[v, :]
        vert_ss[v, :, :] -= np.outer(z, z)
    for e in range(tree_grid.shape[1]):
        z1 = latent_row[tree_grid[1, e], :]
        z2 = latent_row[tree_grid[2, e], :]
        edge_ss[e, :, :] -= np.outer(z1, z2)
    for v, x in enumerate(data_row):
        if np.isnan(x):
            continue
        z = latent_row[v, :]
        feat_ss[v] -= 1
        feat_ss[v, 1] -= x
        feat_ss[v, 2:] -= x * z  # TODO Use central covariance.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号