def pred_sharedparams_sym(self, mus, sigmas, corxy, pis, prediction_method='mixture'):
    r"""Pick one component mean per row of ``pis`` as the prediction.

    Two strategies are supported:

    * ``'mixture'``: every component mean is treated as a candidate point.
      Each candidate is scored by ``sum_j pis[:, j] * prob_j(candidate)``,
      where ``prob_j`` is the bivariate Gaussian density of component ``j``,
      and the highest-scoring candidate is returned.
    * ``'pi'``: the mean of the component with the largest ``pis`` value is
      returned; the densities are not evaluated.

    NOTE(review): shapes below are inferred from the indexing only; confirm
    against the caller. ``T`` is presumably ``theano.tensor``.

    Args:
        mus: 2-D tensor of component means, assumed ``(n_components, 2)``.
        sigmas: 2-D tensor of per-axis standard deviations; columns 0 and 1
            are read, assumed ``(n_components, 2)``.
        corxy: per-component correlation coefficient, assumed
            ``(n_components,)``. Values of exactly +/-1 make ``1 - corxy**2``
            zero and the density undefined.
        pis: 2-D tensor of mixture weights, assumed
            ``(n_samples, n_components)``.
        prediction_method: ``'mixture'`` (default) or ``'pi'``.

    Returns:
        ``mus[preds, :]`` -- the selected mean for each row of ``pis``.

    Raises:
        ValueError: if ``prediction_method`` is neither ``'mixture'`` nor
            ``'pi'``.
    """
    if prediction_method == 'mixture':
        # Pairwise differences between candidate means and component means:
        # diff[i, j, :] = mus[i] - mus[j].
        candidates = mus[:, np.newaxis, :]
        diff = candidates - mus
        diffprod = T.prod(diff, axis=-1)

        # 1 / (sigma_x * sigma_y) for every component.
        sigmainvs = 1.0 / sigmas
        sigmainvprods = sigmainvs[:, 0] * sigmainvs[:, 1]

        # z = dx^2/sx^2 + dy^2/sy^2 - 2*rho*dx*dy/(sx*sy)
        diffsigmanorm = T.sum(diff ** 2 / sigmas ** 2, axis=-1)
        z = diffsigmanorm - 2 * corxy * diffprod * sigmainvprods

        # Bivariate Gaussian density:
        # exp(-z / (2*(1-rho^2))) / (2*pi*sx*sy*sqrt(1-rho^2))
        oneminuscorxy2inv = 1.0 / (1.0 - corxy ** 2)
        expterm = T.exp(-0.5 * z * oneminuscorxy2inv)
        probs = (0.5 / np.pi) * sigmainvprods * T.sqrt(oneminuscorxy2inv) * expterm

        # Weight each component's density by its pi and sum over components,
        # giving one mixture score per (row of pis, candidate mean).
        piprobs = pis[:, np.newaxis, :] * probs
        piprobsum = T.sum(piprobs, axis=-1)
        preds = T.argmax(piprobsum, axis=1)
        return mus[preds, :]

    if prediction_method == 'pi':
        logging.info('only pis are used for prediction')
        preds = T.argmax(pis, axis=1)
        return mus[preds, :]

    # The original raised a bare string, which is itself a TypeError in
    # Python 3; raise a real exception so the message reaches the caller.
    raise ValueError('%s is not a valid prediction method' % prediction_method)
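# 评论列表 (comment list) -- navigation text left over from the source web page, not code
# 文章目录 (table of contents) -- navigation text left over from the source web page, not code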