def apply_sdw_mwf(mix, target_psd_matrix, noise_psd_matrix, mu=1, corr=None):
    """Apply a speech-distortion-weighted multichannel Wiener filter (SDW-MWF).

    For every frequency bin the filter vector is

        h = (Phi_T + mu * Phi_N)^{-1} Phi_T e_ref

    where ``Phi_T`` / ``Phi_N`` are the target / noise PSD matrices and
    ``e_ref`` selects the reference channel. The output is ``h^H x``.

    :param mix: complex STFT of the mixture, shape (bins, sensors, frames).
    :param target_psd_matrix: target PSD matrices, shape (bins, sensors, sensors).
    :param noise_psd_matrix: noise PSD matrices, same shape as
        ``target_psd_matrix``.
    :param mu: Lagrange (trade-off) factor; weight applied to the noise PSD
        matrix in the matrix that is inverted.
    :param corr: optional 1-D sequence/array of per-channel scores used to pick
        the reference channel. If it has more entries than there are sensors,
        its smallest entries are dropped until the lengths match, and the
        reference channel is the argmax of what remains. If ``None``, channel
        0 is the reference.
    :return: filtered single-channel STFT, shape (bins, frames).
    """
    # Unpacking (rather than indexing) also enforces a 3-D ``mix``.
    _, sensors, _ = mix.shape

    if corr is None:
        ref_ch = 0
    else:
        # Choose the channel with the highest correlation with the others.
        # np.asarray lets plain lists/tuples be passed as well as arrays.
        corr = np.asarray(corr).tolist()
        # NOTE(review): removing entries shifts later positions, so the argmax
        # below indexes the *trimmed* list — confirm this matches the intended
        # channel numbering when len(corr) > sensors.
        while len(corr) > sensors:
            corr.remove(np.min(corr))
        ref_ch = int(np.argmax(corr))

    # Phi_T[..., ref_ch] is exactly Phi_T @ e_ref, so no explicit one-hot
    # selection vector is needed. The trailing singleton axis makes the
    # right-hand side a stack of column vectors, which numpy.linalg.solve
    # interprets identically on NumPy 1.x and 2.x (2.x no longer treats a
    # batched 2-D ``b`` as stacked vectors).
    mwf_vector = solve(
        target_psd_matrix + mu * noise_psd_matrix,
        target_psd_matrix[..., ref_ch][..., None],
    )[..., 0]
    # y[bin, frame] = sum_a conj(h[bin, a]) * x[bin, a, frame]
    return np.einsum('...a,...at->...t', mwf_vector.conj(), mix)
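# Web-page navigation residue from the source article, not code:
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").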