def nll_loss_sharedparams(self, mus, sigmas, corxy, pis, y_true):
    """Mean negative log-likelihood of ``y_true`` under a bivariate Gaussian mixture.

    The mixture parameters carry no batch axis, so the same K components
    are shared by every target point. Component k is a 2-D Gaussian with
    means ``mus[k]``, standard deviations ``sigmas[k, 0]`` / ``sigmas[k, 1]``
    and x/y correlation ``corxy[k]``, weighted by ``pis[k]``.

    Shapes below are inferred from the indexing in this function
    (TODO confirm against callers):

    Args:
        mus: Component means, (K, 2).
        sigmas: Component standard deviations, (K, 2); must be > 0.
        corxy: Per-component correlation coefficient, (K,); must lie in (-1, 1).
        pis: Mixture weights, (K,); presumably non-negative and summing to 1
            (not checked here).
        y_true: Target points, (N, 2).

    Returns:
        Scalar: the negative log-likelihood averaged over the N targets.
    """
    # (N, 1, 2) targets minus (1, K, 2) means -> (N, K, 2) residuals.
    mus_ex = mus[np.newaxis, :, :]
    X = y_true[:, np.newaxis, :]
    diff = X - mus_ex
    diffprod = T.prod(diff, axis=-1)  # dx * dy, (N, K)
    corxy2 = corxy ** 2
    sigmainvprods = 1.0 / (sigmas[:, 0] * sigmas[:, 1])  # 1 / (sx * sy), (K,)
    diffsigmanorm = T.sum(diff ** 2 / sigmas ** 2, axis=-1)  # (N, K)

    # Quadratic form of the bivariate normal, before the 1 / (1 - rho^2) factor.
    z = diffsigmanorm - 2 * corxy * diffprod * sigmainvprods
    expterm = -0.5 * z / (1.0 - corxy2)

    # Per-component log normaliser, computed directly in log space:
    #   -log(2*pi) - log(sx) - log(sy) - 0.5 * log(1 - rho^2)
    # This avoids log(1/x) round-trips and applying np.sqrt to a symbolic tensor.
    log_norm = (
        np.log(0.5 / np.pi)
        - T.log(sigmas[:, 0])
        - T.log(sigmas[:, 1])
        - 0.5 * T.log1p(-corxy2)
    )
    new_exponent = log_norm + expterm + T.log(pis)  # (N, K)

    # Log-sum-exp over components. Both terms keep shape (N, 1) so they add
    # elementwise; mixing (N, 1) with (N,) would silently broadcast to (N, N).
    max_exponent = T.max(new_exponent, axis=1, keepdims=True)
    gauss_mix = T.sum(T.exp(new_exponent - max_exponent), axis=1, keepdims=True)
    log_gauss = max_exponent + T.log(gauss_mix)  # (N, 1)
    return -T.mean(log_gauss)
评论列表
文章目录