discriminator.py source code

Language: Python

Project: plda  Author: RaviSoji
```python
import numpy as np
from scipy.special import logsumexp  # logsumexp lives in scipy.special in current SciPy


def calc_prob_same_diff(self, data_pairs, return_log=True, norm_probs=True,
                        return_prob='same'):
    # Method of the discriminator class: scores each pair of items in
    # data_pairs under the "same" and "diff" hypotheses.
    assert data_pairs.shape[-2] == 2
    assert isinstance(return_log, bool)
    assert return_prob in ('same', 'diff')

    log_ps_diff = self.calc_probs_diff(data_pairs)
    log_ps_same = self.calc_probs_same(data_pairs)
    assert log_ps_diff.shape[-2] == log_ps_diff.shape[-1]
    assert log_ps_diff.shape[-1] == log_ps_same.shape[-1]

    log_prob_diff = logsumexp(log_ps_diff, axis=(-1, -2))
    # There are 42 "different" probabilities and 7 "same" ones (42 = 7 * 6).
    # Multiplying the "same" total by six makes the prior 50/50 on the
    # same/diff task.
    log_prob_same = logsumexp(log_ps_same, axis=-1) + np.log(6)

    if return_prob == 'same':
        log_probs = log_prob_same
    else:
        log_probs = log_prob_diff

    if norm_probs:
        # Normalize in log space so that P(same) + P(diff) = 1.
        norms = logsumexp([log_prob_diff, log_prob_same], axis=0)
        log_probs = log_probs - norms

    if return_log:
        return log_probs
    return np.exp(log_probs)
```
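The effect of the factor of 6 can be checked in isolation. The sketch below is a standalone re-implementation of just the weighting and normalization step, not part of plda: `prob_same` and its `same_weight` parameter are hypothetical names, the log-likelihoods are made-up uniform inputs, and setting the diagonal of the 7×7 "diff" array to log(0) is an assumption made so that exactly 42 "different" entries contribute, matching the comment in the code. With every active hypothesis equally likely, the weight of 6 gives P(same) = 42 / (42 + 42) = 0.5.

```python
import numpy as np
from scipy.special import logsumexp


def prob_same(log_ps_same, log_ps_diff, same_weight=6.0):
    """Hypothetical helper: normalized P(same) from per-hypothesis log-likelihoods.

    log_ps_same: array of shape (..., K)
    log_ps_diff: array of shape (..., K, K)
    same_weight: multiplier on the "same" total (6 in the snippet above).
    """
    log_prob_diff = logsumexp(log_ps_diff, axis=(-1, -2))
    log_prob_same = logsumexp(log_ps_same, axis=-1) + np.log(same_weight)
    # Normalize in log space so the two hypotheses sum to 1.
    norms = np.logaddexp(log_prob_diff, log_prob_same)
    return np.exp(log_prob_same - norms)


K = 7
# Made-up uniform case: every hypothesis has log-likelihood 0 (likelihood 1).
log_ps_same = np.zeros(K)
log_ps_diff = np.zeros((K, K))
# Assumption: the diagonal of the "diff" array is log(0) = -inf,
# leaving K * (K - 1) = 42 active "different" entries.
np.fill_diagonal(log_ps_diff, -np.inf)

print(prob_same(log_ps_same, log_ps_diff))  # approximately 0.5
```

Passing `same_weight=1.0` instead gives P(same) = 7 / 49 = 1/7, which shows how the unweighted sums would bias the decision toward "diff" simply because there are more "different" terms.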