def get_power_spectral_density_matrix(observation, mask=None):
    """
    Calculate the mask-weighted power spectral density (PSD) matrix.

    For every frequency bin, the per-frame outer products of the
    multichannel observation with its own conjugate are weighted by ``mask``
    and averaged over frames::

        psd[f, d, e] = sum_t mask[f, t] * x[f, d, t] * conj(x[f, e, t])
                       / max(sum_t mask[f, t], 1e-6)

    Any leading dimensions in front of ``bins`` (e.g. a batch axis) are
    treated independently and broadcast between ``observation`` and
    ``mask`` following NumPy broadcasting rules.

    :param observation: Complex observations with shape
        (..., bins, sensors, frames). Array-likes are accepted.
    :param mask: Optional weights with shape (..., bins, frames) or
        (..., bins, 1, frames). Defaults to all ones, i.e. an unweighted
        average over frames. Array-likes are accepted.
    :return: PSD matrix with shape (..., bins, sensors, sensors).
    :raises ValueError: if ``observation`` has fewer than three dimensions.
    """
    observation = np.asarray(observation)
    if observation.ndim < 3:
        raise ValueError(
            'observation must have shape (..., bins, sensors, frames), '
            'got shape {}'.format(observation.shape)
        )
    if mask is None:
        # One unit weight per (bin, frame): a plain average over frames.
        mask = np.ones(observation.shape[:-2] + observation.shape[-1:])
    else:
        mask = np.asarray(mask)
    if mask.ndim < observation.ndim:
        # The mask lacks the sensor axis: insert a singleton one so the same
        # weight is applied to every sensor of a bin/frame.
        mask = mask[..., np.newaxis, :]
    # Per-bin sum of weights; floored so bins whose mask is all zero yield a
    # zero matrix instead of a division by zero.
    normalization = np.maximum(np.sum(mask, axis=-1, keepdims=True), 1e-6)
    psd = np.einsum('...dt,...et->...de', mask * observation,
                    observation.conj())
    # Out-of-place division: an in-place ``/=`` raises a casting error when
    # ``psd`` is integer-typed (integer observation with an int/bool mask).
    return psd / normalization
# Residue scraped from the source web page ("Comment list", "Table of contents").