utils.py source code

python

Project: elfi  Author: elfi-dev
import numpy as np


def distance_as_discrepancy(dist, *summaries, observed):
    """Evaluate a distance function with signature `dist(summaries, observed)` in ELFI."""
    # Stack the summary outputs column-wise into one (batch_size, n_features) array
    summaries = np.column_stack(summaries)
    # Ensure each observed summary is 2d, then join them column-wise into one 2d array
    observed = np.concatenate([np.atleast_2d(o) for o in observed], axis=1)
    try:
        d = dist(summaries, observed)
    except ValueError as e:
        raise ValueError('Incompatible data shape for the distance node. Please check '
                         'summary (XA) and observed (XB) output data dimensions. They '
                         'have to be at most 2d. Especially ensure that summary nodes '
                         'output 2d data even with batch_size=1. Original error message '
                         'was: {}'.format(e)) from e
    # A single observation gives an (n, 1) distance matrix; flatten it to shape (n,)
    if d.ndim == 2 and d.shape[1] == 1:
        d = d.reshape(-1)
    return d
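To make the shape handling concrete, here is a minimal sketch that walks through the same steps by hand for a batch of three simulations and two summaries. It assumes a cdist-style distance (pairwise distances between the rows of XA and XB, as `scipy.spatial.distance.cdist` computes); `euclidean_dist` is a hypothetical NumPy-only stand-in so the example has no SciPy dependency, and the summary values are made up for illustration.

```python
import numpy as np


def euclidean_dist(XA, XB):
    """Hypothetical cdist-style distance: returns an array of shape (len(XA), len(XB))."""
    # Pairwise differences via broadcasting: (nA, 1, k) - (1, nB, k) -> (nA, nB, k)
    diff = XA[:, None, :] - XB[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


# Two summary nodes evaluated on a batch of 3 simulations (illustrative values)
s1 = np.array([0.0, 1.0, 2.0])          # shape (3,)
s2 = np.array([[0.0], [0.0], [3.0]])    # shape (3, 1)
summaries = np.column_stack([s1, s2])   # shape (3, 2)

# Observed values of the same two summaries, one observation each
observed = (np.array([0.0]), np.array([0.0]))
obs = np.concatenate([np.atleast_2d(o) for o in observed], axis=1)  # shape (1, 2)

d = euclidean_dist(summaries, obs)  # shape (3, 1): one column, one observation
d = d.reshape(-1)                   # shape (3,): one discrepancy per simulation
print(d)
```

Because `np.column_stack` turns 1d inputs into columns, summaries of shape `(3,)` and `(3, 1)` can be mixed; the final reshape is what lets the distance node return a flat discrepancy vector.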