def __init__(self, gmm, swap=False, diff=False):
    """Split a joint source/target GMM into per-stream parameters.

    ``gmm`` is a Gaussian mixture fitted on joint vectors ``[x; y]``. The
    first half of the feature dimension is treated as the source stream and
    the second half as the target stream.

    Args:
        gmm: Fitted mixture exposing ``covariance_type``, ``means_`` of shape
            ``(M, 2D)``, ``covariances_`` of shape ``(M, 2D, 2D)`` and
            ``weights_``.
        swap: If True, exchange the source and target parameters. This is
            applied after the ``diff`` re-parametrization.
        diff: If True, re-parametrize the target stream as the difference
            ``y - x``. Its means become ``mu_y - mu_x`` and the covariance
            blocks are transformed accordingly.

    Attributes set:
        num_mixtures, weights, src_means, tgt_means, covarXX, covarXY,
        covarYX, covarYY, and ``px``, a ``GaussianMixture`` over the source
        stream only.

    Raises:
        ValueError: If ``gmm`` does not use full covariances, or if its
            feature dimension is odd and cannot be split evenly into source
            and target halves.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``. A non-"full" model would then fail later with an opaque
    # IndexError from the 3-D covariance slicing below.
    if gmm.covariance_type != "full":
        raise ValueError(
            "joint GMM must use covariance_type='full', got {!r}".format(
                gmm.covariance_type))

    num_features = gmm.means_.shape[1]
    # An odd joint dimension would silently produce a D-dim source stream
    # and a (D+1)-dim target stream.
    if num_features % 2 != 0:
        raise ValueError(
            "joint GMM feature dimension must be even (source + target), "
            "got {}".format(num_features))

    # D: static + delta dim of each stream
    D = num_features // 2
    self.num_mixtures = gmm.means_.shape[0]
    self.weights = gmm.weights_

    # Split source and target parameters from the joint GMM.
    # NumPy basic slicing returns views into gmm's arrays, not copies.
    self.src_means = gmm.means_[:, :D]
    self.tgt_means = gmm.means_[:, D:]
    self.covarXX = gmm.covariances_[:, :D, :D]
    self.covarXY = gmm.covariances_[:, :D, D:]
    self.covarYX = gmm.covariances_[:, D:, :D]
    self.covarYY = gmm.covariances_[:, D:, D:]

    if diff:
        # Model z = y - x instead of y:
        #   E[z]      = mu_y - mu_x
        #   Cov(z, z) = XX + YY - XY - YX
        #   Cov(x, z) = XY - XX
        #   Cov(z, x) = Cov(x, z)^T  (XX is symmetric and YX = XY^T)
        self.tgt_means = self.tgt_means - self.src_means
        self.covarYY = self.covarXX + self.covarYY - self.covarXY - self.covarYX
        self.covarXY = self.covarXY - self.covarXX
        self.covarYX = self.covarXY.transpose(0, 2, 1)

    # Swap the source and target parameters.
    if swap:
        self.tgt_means, self.src_means = self.src_means, self.tgt_means
        self.covarYY, self.covarXX = self.covarXX, self.covarYY
        self.covarYX, self.covarXY = self.covarXY, self.covarYX

    # p(x) is used to compute posterior probabilities for a given source
    # spectral feature in the mapping stage. Its parameters are assigned
    # directly rather than fitted.
    self.px = GaussianMixture(
        n_components=self.num_mixtures, covariance_type="full")
    self.px.means_ = self.src_means
    self.px.covariances_ = self.covarXX
    self.px.weights_ = self.weights
    # scikit-learn's scoring code reads ``precisions_cholesky_`` rather than
    # ``covariances_``, so it must be derived explicitly here.
    self.px.precisions_cholesky_ = _compute_precision_cholesky(
        self.px.covariances_, "full")
评论列表
文章目录