def _accumulate_sufficient_statistics(self, stats, obs, framelogprob,
                                      posteriors, fwdlattice, bwdlattice):
    """Fold one sequence's Gaussian emission statistics into ``stats``.

    The parent class's accumulator is invoked first with the same
    arguments. Afterwards the emission-specific entries of ``stats`` are
    updated in place, depending on the letters present in ``self.params``:

    * ``'m'`` or ``'c'``: ``stats['post']`` gains the per-state sum of
      ``posteriors`` over time, and ``stats['obs']`` gains the
      posterior-weighted sum of observations (``posteriors.T . obs``).
    * ``'c'`` with ``covariance_type`` in ``('spherical', 'diag')``:
      ``stats['obs**2']`` gains ``posteriors.T . obs**2``.
    * ``'c'`` with ``covariance_type`` in ``('tied', 'full')``:
      ``stats['obs*obs.T']`` gains, per state, the posterior-weighted sum
      of outer products of each observation with itself.

    Any other ``covariance_type`` leaves the covariance statistics
    untouched.

    Parameters
    ----------
    stats : dict
        Accumulator modified in place. The keys listed above are updated
        with ``+=``, so they must already exist.
    obs : array of shape (nt, nf)
        Observations of one sequence.
    framelogprob, fwdlattice, bwdlattice
        Used here only by forwarding them to the parent implementation.
    posteriors : array of shape (nt, nc)
        Per-time-step state posteriors.
    """
    super(GaussianHMM, self)._accumulate_sufficient_statistics(
        stats, obs, framelogprob, posteriors, fwdlattice, bwdlattice)

    params = self.params
    wants_covars = 'c' in params

    # First-order statistics are shared by the mean and covariance updates.
    if wants_covars or 'm' in params:
        stats['post'] += np.sum(posteriors, axis=0)
        stats['obs'] += posteriors.T.dot(obs)

    if not wants_covars:
        return

    cov_type = self.covariance_type
    if cov_type in ('tied', 'full'):
        # Weighted outer products, one (nf, nf) matrix per state:
        # (nt, nc) x (nt, nf) x (nt, nf) -> (nc, nf, nf)
        stats['obs*obs.T'] += np.einsum(
            'tc,ti,tj->cij', posteriors, obs, obs)
    elif cov_type in ('spherical', 'diag'):
        # Only per-feature second moments are needed for these shapes.
        stats['obs**2'] += posteriors.T.dot(np.square(obs))