loc2lang_withpi.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:geomdn 作者: afshinrahimi 项目源码 文件源码
def pred_sharedparams_sym(self, mus, sigmas, corxy, pis, prediction_method='mixture'):
        r'''
        Build a symbolic expression selecting one row of ``mus`` per sample.

        If ``prediction_method`` is ``'mixture'``, select the mu that maximizes
        \sum_j pi_j * prob_j(mu), where prob_j is the bivariate Gaussian
        density of component j evaluated at every candidate mu.
        If ``prediction_method`` is ``'pi'``, select the mu of the component
        with the highest pi.

        Shape assumptions (not enforced here -- TODO confirm against callers):
        ``mus`` and ``sigmas`` are (n_components, 2), ``corxy`` is
        (n_components,), and ``pis`` is (n_samples, n_components).

        Returns ``mus[preds, :]`` where ``preds`` is the argmax over axis 1.

        Raises ValueError if ``prediction_method`` is neither ``'mixture'``
        nor ``'pi'``.
        '''
        if prediction_method == 'mixture':
            # Pairwise differences between every candidate mu (first axis)
            # and every component mean (second axis), via broadcasting.
            candidates = mus[:, np.newaxis, :]
            diff = candidates - mus
            diffprod = T.prod(diff, axis=-1)

            sigmainvs = 1.0 / sigmas
            sigmainvprods = sigmainvs[:, 0] * sigmainvs[:, 1]

            # Quadratic form of the bivariate Gaussian exponent:
            # z = dx^2/sx^2 + dy^2/sy^2 - 2*rho*dx*dy/(sx*sy)
            diffsigmanorm = T.sum(diff ** 2 / sigmas ** 2, axis=-1)
            z = diffsigmanorm - 2 * corxy * diffprod * sigmainvprods

            oneminuscorxy2inv = 1.0 / (1.0 - corxy ** 2)
            expterm = T.exp(-0.5 * z * oneminuscorxy2inv)

            # Density of each component at each candidate mu.
            probs = (0.5 / np.pi) * sigmainvprods * T.sqrt(oneminuscorxy2inv) * expterm

            # Weight the densities by the mixing coefficients and sum over the
            # component axis to score every candidate.
            piprobs = pis[:, np.newaxis, :] * probs
            piprobsum = T.sum(piprobs, axis=-1)
            preds = T.argmax(piprobsum, axis=1)
            return mus[preds, :]

        if prediction_method == 'pi':
            logging.info('only pis are used for prediction')
            preds = T.argmax(pis, axis=1)
            return mus[preds, :]

        # The original raised a bare string, which itself fails with TypeError
        # in Python 3; raise a real exception carrying the same message.
        raise ValueError('%s is not a valid prediction method' % prediction_method)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号