def distance_as_discrepancy(dist, *summaries, observed):
    """Evaluate a distance function with signature `dist(summaries, observed)` in ELFI.

    Parameters
    ----------
    dist : callable
        Called exactly once as ``dist(summaries, observed)`` with the two
        stacked arrays built below.
    *summaries
        Summary outputs. They are combined column-wise with
        ``np.column_stack`` before being handed to `dist`.
    observed : iterable
        Observed counterparts of the summaries. Every item is promoted to at
        least 2d and the items are concatenated along axis 1.

    Returns
    -------
    The value returned by `dist`. A 2d result with a single column is
    flattened to 1d; any other result is returned unchanged.

    Raises
    ------
    ValueError
        If `dist` raises ``ValueError``. The error is re-raised with a hint
        about the expected data shapes, and the original exception is kept as
        ``__cause__``. Errors raised while stacking the inputs are not wrapped.

    """
    summaries = np.column_stack(summaries)
    # Ensure observed are 2d
    observed = np.concatenate([np.atleast_2d(o) for o in observed], axis=1)

    try:
        d = dist(summaries, observed)
    except ValueError as e:
        # Chain explicitly so the traceback reports the original failure as
        # the direct cause of this more descriptive error.
        raise ValueError('Incompatible data shape for the distance node. Please check '
                         'summary (XA) and observed (XB) output data dimensions. They '
                         'have to be at most 2d. Especially ensure that summary nodes '
                         'outputs 2d data even with batch_size=1. Original error message '
                         'was: {}'.format(e)) from e

    # Collapse a single-column 2d result into a flat 1d vector.
    if d.ndim == 2 and d.shape[1] == 1:
        d = d.reshape(-1)
    return d
评论列表
文章目录