def held_out_log_predicitive(clustering, dist, partition_prior, test_data, train_data, per_point=False):
    """Score ``test_data`` under a mixture built from a clustering of ``train_data``.

    The clustering is first passed through ``relabel_clustering``. For every
    unique label in the result, parameters are built from the rows of
    ``train_data`` carrying that label, and the component's unnormalised log
    weight is ``partition_prior.log_tau_2_diff(params.N)``. One extra
    component is then appended, built from ``dist.create_params()`` (no data)
    and weighted by ``partition_prior.log_tau_1_diff(<number of labels>)``.
    The weights are passed through ``log_normalize`` before being combined
    with ``dist.log_predictive_likelihood_bulk(test_data, params)`` for each
    component.

    Args:
        clustering: Cluster labels for ``train_data``; compared element-wise
            against each label to build a row mask.
            NOTE(review): presumably one label per training row -- confirm.
        dist: Object providing ``create_params_from_data``, ``create_params``
            and ``log_predictive_likelihood_bulk``.
        partition_prior: Object providing ``log_tau_1_diff`` and
            ``log_tau_2_diff``.
        test_data: Array whose first axis indexes the held-out points.
        train_data: Array indexed by the boolean mask ``labels == label``.
        per_point: When true, return the ``log_sum_exp(..., axis=1)`` result
            itself; otherwise return its ``np.sum``.

    Returns:
        ``log_sum_exp`` over the component axis of the weighted log
        likelihood matrix (presumably one value per test point -- depends on
        ``log_sum_exp``), or the sum of those values when ``per_point`` is
        false.
    """
    labels = relabel_clustering(clustering)
    unique_labels = sorted(np.unique(labels))

    # Parallel lists: component_params[i] pairs with raw_log_weights[i].
    component_params = []
    raw_log_weights = []
    for label in unique_labels:
        fitted = dist.create_params_from_data(train_data[labels == label])
        component_params.append(fitted)
        raw_log_weights.append(partition_prior.log_tau_2_diff(fitted.N))

    # Trailing component built without any training data
    # (presumably the "new block" option -- confirm against the prior's API).
    component_params.append(dist.create_params())
    raw_log_weights.append(partition_prior.log_tau_1_diff(len(unique_labels)))

    log_weights = log_normalize(np.array(raw_log_weights))

    # Rows: test points; columns: mixture components (including the trailing one).
    weighted_log_lik = np.zeros((test_data.shape[0], len(log_weights)))
    for column, (log_weight, params) in enumerate(zip(log_weights, component_params)):
        weighted_log_lik[:, column] = log_weight + dist.log_predictive_likelihood_bulk(test_data, params)

    point_scores = log_sum_exp(weighted_log_lik, axis=1)
    return point_scores if per_point else np.sum(point_scores)
评论列表
文章目录