utils.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:pgsm 作者: aroth85 项目源码 文件源码
def held_out_log_predicitive(clustering, dist, partition_prior, test_data, train_data, per_point=False):
    """Score held-out data under the predictive mixture implied by a clustering.

    The training points are grouped by their (relabelled) cluster label. For
    each existing block, a parameter object is built from that block's
    training rows and a prior log-weight is taken from
    ``partition_prior.log_tau_2_diff(params.N)``. One extra component stands
    for a new, empty block: its parameters come from ``dist.create_params()``
    and its prior log-weight from
    ``partition_prior.log_tau_1_diff(<number of existing blocks>)``. The prior
    log-weights are normalised with ``log_normalize``. Each test row is then
    scored as the log-sum-exp over components of
    ``weight + dist.log_predictive_likelihood_bulk(test_data, params)``.

    The spelling of the function name ("predicitive") is kept so that existing
    callers continue to work.

    Args:
        clustering: Cluster label for each training row. It is used as a
            boolean mask against ``train_data``, so it is presumably aligned
            with ``train_data`` along axis 0 (confirm at the call sites).
        dist: Object providing ``create_params_from_data``,
            ``create_params`` and ``log_predictive_likelihood_bulk``.
        partition_prior: Object providing ``log_tau_1_diff`` and
            ``log_tau_2_diff``. These are presumably the prior log-weights for
            opening a new block and for joining an existing block, respectively
            (TODO confirm).
        test_data: Held-out rows; only ``test_data.shape[0]`` is read directly.
        train_data: Training rows, indexed by a boolean mask built from the
            labels.
        per_point: If true, return the per-test-row scores instead of their
            sum.

    Returns:
        The result of ``log_sum_exp(..., axis=1)`` over the components when
        ``per_point`` is true, otherwise the ``np.sum`` of that result.
    """
    labels = relabel_clustering(clustering)
    unique_labels = sorted(np.unique(labels))
    n_existing = len(unique_labels)

    # Build parameters and the prior log-weight for each existing block,
    # interleaved per block.
    cluster_params = []
    prior_log_weights = []
    for label in unique_labels:
        fitted = dist.create_params_from_data(train_data[labels == label])
        cluster_params.append(fitted)
        prior_log_weights.append(partition_prior.log_tau_2_diff(fitted.N))

    # Trailing component for a brand-new (empty) block.
    cluster_params.append(dist.create_params())
    prior_log_weights.append(partition_prior.log_tau_1_diff(n_existing))

    prior_log_weights = log_normalize(np.array(prior_log_weights))

    # Column k holds log w_k + log p(test row | component k).
    joint = np.zeros((test_data.shape[0], len(prior_log_weights)))
    for col, (log_w, component) in enumerate(zip(prior_log_weights, cluster_params)):
        joint[:, col] = log_w + dist.log_predictive_likelihood_bulk(test_data, component)

    per_row = log_sum_exp(joint, axis=1)
    if not per_point:
        return np.sum(per_row)
    return per_row
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号