def calc_prob_same_diff(self, data_pairs, return_log=True, norm_probs=True,
return_prob='same'):
assert data_pairs.shape[-2] == 2
assert isinstance(return_log, bool)
assert return_prob == 'same' or return_prob == 'diff'
log_ps_diff = self.calc_probs_diff(data_pairs)
log_ps_same = self.calc_probs_same(data_pairs)
assert log_ps_diff.shape[-2] == log_ps_diff.shape[-1]
assert log_ps_diff.shape[-1] == log_ps_same.shape[-1]
log_prob_diff = logsumexp(log_ps_diff, axis=(-1, -2))
log_prob_same = logsumexp(log_ps_same, axis=-1) + np.log(6)
# Since there are 42 "different probabilities" and 7 "same".
# Multiplying by six makes the prior 50/50 on the same/diff task.
if return_prob == 'same':
log_probs = log_prob_same
else:
log_probs = log_prob_diff
if norm_probs is True:
norms = logsumexp([log_prob_diff, log_prob_same], axis=0)
log_probs = log_probs - norms
if return_log is True:
return log_probs
else:
return np.exp(log_probs)
评论列表
文章目录