def treegauss_remove_row(
        data_row,
        tree_grid,
        latent_row,
        vert_ss,
        edge_ss,
        feat_ss, ):
    """Subtract one row's contribution from tree-Gaussian sufficient statistics.

    All statistic arrays are modified in place; nothing is returned.

    Args:
        data_row: 1-D sequence of feature values, one per vertex. NaN entries
            are treated as missing and leave ``feat_ss`` untouched.
        tree_grid: Integer array with one column per edge; rows 1 and 2 hold
            the two endpoint vertex ids of that edge.
        latent_row: Array of shape ``(V, M)``; row ``v`` is the latent vector
            ``z_v`` of vertex ``v``.
        vert_ss: Per-vertex statistics; ``vert_ss[v]`` loses
            ``outer(z_v, z_v)`` for every ``v < V``.
        edge_ss: Per-edge statistics; ``edge_ss[e]`` loses ``outer(z_a, z_b)``
            where ``a, b = tree_grid[1, e], tree_grid[2, e]``.
        feat_ss: Per-feature statistics, expected shape ``(V, 2 + M)``. For
            each observed value ``x`` at vertex ``v``: column 0 is decremented
            by 1 (presumably the observation count), column 1 by ``x``, and
            columns ``2:`` by ``x * z_v``.
    """
    num_verts = latent_row.shape[0]
    # An einsum with no contracted index computes each entry as a single
    # product, so this equals np.outer(z_v, z_v) for every vertex exactly.
    vert_ss[:num_verts] -= np.einsum('vi,vj->vij', latent_row, latent_row)

    num_edges = tree_grid.shape[1]
    # Guard: an empty grid may not carry an integer dtype usable for indexing.
    if num_edges:
        z1 = latent_row[tree_grid[1, :]]
        z2 = latent_row[tree_grid[2, :]]
        edge_ss[:num_edges] -= np.einsum('ei,ej->eij', z1, z2)

    data_row = np.asarray(data_row)
    # Vertex ids with an observed (non-NaN) value; unique, so fancy-indexed
    # in-place subtraction below applies exactly once per vertex.
    observed = np.flatnonzero(~np.isnan(data_row))
    x = data_row[observed]
    z = latent_row[observed]
    # Decrement only the count column (index 0). Decrementing the whole row
    # would also shift the sum and cross-moment columns by one.
    # NOTE(review): the matching add-row routine must increment column 0 the
    # same way, or add/remove of a row will no longer cancel.
    feat_ss[observed, 0] -= 1
    feat_ss[observed, 1] -= x
    feat_ss[observed, 2:] -= x[:, np.newaxis] * z  # TODO Use central covariance.