def precompute_marginals(self):
    """Precompute a distribution object for every non-empty marginal.

    For each instance index in ``range(self._num_instances)`` this reads
    ``self._corrected_means[idx]`` and ``self._corrected_covs[idx]`` and
    fills ``self._pdfs[idx]`` with a list of ``2 ** d`` slots, where ``d``
    is the length of the mean vector.  Every True/False pattern over the
    ``d`` dimensions that selects at least one dimension gets a
    ``multivariate_normal`` built from the matching sub-mean and
    sub-covariance; the all-False pattern is skipped and its slot stays
    ``None``.

    The slot for a pattern is ``hash_bool_array(mask)``.
    NOTE(review): presumably that helper maps a boolean mask to a unique
    integer in ``[0, 2 ** d)`` -- confirm against its definition.
    NOTE(review): ``multivariate_normal`` looks like
    ``scipy.stats.multivariate_normal`` -- confirm against the imports.

    A progress message is written to standard error before the work starts.
    Returns nothing; the result is stored on ``self._pdfs``.
    """
    sys.stderr.write('Precomputing marginals...\n')
    self._pdfs = [None] * self._num_instances

    for idx in xrange(self._num_instances):
        mu = self._corrected_means[idx]
        sigma = self._corrected_covs[idx]
        dim = mu.shape[0]

        # One slot per True/False pattern over the `dim` dimensions.
        table = [None] * (2 ** dim)
        self._pdfs[idx] = table

        for pattern in itertools.product([False, True], repeat=dim):
            kept = pattern.count(True)
            if kept == 0:
                # Nothing selected: there is no marginal to build.
                continue

            mask = np.array(pattern)
            sub_mean = mu[mask]

            # Outer AND of the mask with itself marks the entries whose row
            # and column are both selected; boolean indexing returns them
            # flattened in row-major order, so reshape back to a square.
            pair_mask = mask[:, np.newaxis] & mask[np.newaxis, :]
            sub_cov = sigma[pair_mask].reshape((kept, kept))

            table[hash_bool_array(mask)] = multivariate_normal(mean=sub_mean, cov=sub_cov)
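# [page residue] "Comment list" -- navigation label from the source web page, not code
# [page residue] "Table of contents" -- navigation label from the source web page, not code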