lang2loc_mdnshared.py source code

python

Project: geomdn  Author: afshinrahimi
import logging

import numpy as np


def pred_sharedparams(self, mus, sigmas, corxy, pis, prediction_method='mixture'):
        r"""
        Given a mixture of Gaussians, select the component mean (mu) that best represents it.
        There are two modes:
        If prediction_method == 'mixture', return the mu that maximizes
        \mathcal{P}(\boldsymbol{x}) = \sum_{k=1}^{K} \pi_k \mathcal{N}(\boldsymbol{x} \vert \boldsymbol{\mu}_k, \Sigma_k)
        where x ranges over the component means mu_1, ..., mu_K.

        If prediction_method == 'pi', return the mu that has the largest pi.
        """

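        if prediction_method == 'mixture':
            # Assuming mus and sigmas have shape (K, 2), build all pairwise
            # differences between component means: diff[i, j] = mu_i - mu_j.
            X = mus[:, np.newaxis, :]
            diff = X - mus
            diffprod = np.prod(diff, axis=-1)
            sigmainvs = 1.0 / sigmas
            sigmainvprods = sigmainvs[:, 0] * sigmainvs[:, 1]
            sigmas2 = sigmas ** 2
            corxy2 = corxy ** 2
            diff2 = diff ** 2
            diffsigma = diff2 / sigmas2
            diffsigmanorm = np.sum(diffsigma, axis=-1)
            # Quadratic form of the bivariate normal:
            # dx^2/sx^2 + dy^2/sy^2 - 2*rho*dx*dy/(sx*sy)
            z = diffsigmanorm - 2 * corxy * diffprod * sigmainvprods
            oneminuscorxy2inv = 1.0 / (1.0 - corxy2)
            term = -0.5 * z * oneminuscorxy2inv
            expterm = np.exp(term)
            # probs[i, j] is the density of component j evaluated at mu_i.
            probs = (0.5 / np.pi) * sigmainvprods * np.sqrt(oneminuscorxy2inv) * expterm
            # Weight by each sample's mixing coefficients and sum over components.
            piprobs = pis[:, np.newaxis, :] * probs
            piprobsum = np.sum(piprobs, axis=-1)
            # For each sample, keep the mean with the highest mixture density.
            preds = np.argmax(piprobsum, axis=1)
            selected_mus = mus[preds, :]
            return selected_mus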

        elif prediction_method == 'pi':
            logging.info('only pis are used for prediction')
            preds = np.argmax(pis, axis=1)
            selected_mus = mus[preds, :]
            #selected_sigmas = sigmas[np.arange(sigmas.shape[0]), :, preds]
            #selected_corxy = corxy[np.arange(corxy.shape[0]),preds]
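            #selected_pis = pis[np.arange(pis.shape[0]),preds]
            return selected_mus

        else:
            # Fail loudly instead of silently returning None.
            raise ValueError('unknown prediction_method: %r' % (prediction_method,))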
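
The vectorized 'mixture' branch above can be hard to follow, so here is a minimal loop-based sketch of the same idea, compared against the 'pi' rule. It is not part of geomdn: the helper names `bivariate_normal_pdf` and `predict_mixture_loop` and the toy inputs are hypothetical. It assumes `mus` and `sigmas` have shape (K, 2), `corxy` has shape (K,), and `pis` has shape (N, K).

```python
import numpy as np


def bivariate_normal_pdf(x, mu, sigma, rho):
    """Density of a 2-D Gaussian with per-axis std devs `sigma` and correlation `rho`."""
    dx = (x[0] - mu[0]) / sigma[0]
    dy = (x[1] - mu[1]) / sigma[1]
    z = dx ** 2 + dy ** 2 - 2.0 * rho * dx * dy
    norm = 2.0 * np.pi * sigma[0] * sigma[1] * np.sqrt(1.0 - rho ** 2)
    return np.exp(-0.5 * z / (1.0 - rho ** 2)) / norm


def predict_mixture_loop(mus, sigmas, corxy, pis):
    """For each sample, return the component mean with the highest mixture density."""
    n_samples, n_components = pis.shape
    scores = np.zeros((n_samples, n_components))
    for i in range(n_components):      # candidate point: mu_i
        for j in range(n_components):  # mixture component j
            p_ij = bivariate_normal_pdf(mus[i], mus[j], sigmas[j], corxy[j])
            scores[:, i] += pis[:, j] * p_ij
    return mus[np.argmax(scores, axis=1)]


# Hypothetical toy inputs: K = 3 shared components, N = 2 samples.
mus = np.array([[0.0, 0.0], [0.1, 0.1], [5.0, 5.0]])
sigmas = np.ones((3, 2))
corxy = np.zeros(3)
pis = np.array([[0.25, 0.35, 0.40],
                [0.10, 0.10, 0.80]])

mixture_pred = predict_mixture_loop(mus, sigmas, corxy, pis)
pi_pred = mus[np.argmax(pis, axis=1)]
```

For the first sample the two rules disagree: the largest single weight belongs to the far-away component at (5, 5), but the two nearby components reinforce each other, so the mixture rule selects (0.1, 0.1). For the second sample both rules select (5, 5).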