def compute_mse_loss(x, xhat, hparams):
    """Build the reconstruction loss between `x` and `xhat`.

    For raw audio the loss is the plain mean squared error over the whole
    tensor. Otherwise the inputs are indexed as rank-4 tensors whose last
    axis holds the magnitude at index 0 and the phase at index 1, and the
    loss is a frequency-weighted magnitude error plus (unless
    `hparams.mag_only` is set) a weighted phase term.

    Args:
      x: Input data tensor.
      xhat: Reconstruction tensor.
      hparams: Hyperparameters. Attributes read here: `raw_audio`,
        `cost_phase_mask`, `fw_loss_coeff`, `fw_loss_cutoff`, `n_fft`,
        `mag_only`, `dphase` and `phase_loss_coeff`.

    Returns:
      total_loss: Loss scalar. It is also written to the "Loss/Total"
        summary; "Loss/Mag" and "Loss/Phase" summaries are written only
        when a phase term is computed.
    """
    with tf.name_scope("Losses"):
        if hparams.raw_audio:
            total_loss = tf.reduce_mean((x - xhat)**2)
        else:
            # Optional weighting of the phase error by the target magnitude
            # channel; a plain 1.0 leaves the phase error unweighted.
            phase_mask = x[:, :, :, 0] if hparams.cost_phase_mask else 1.0
            freq_mask = utils.frequency_weighted_cost_mask(
                hparams.fw_loss_coeff,
                hz_flat=hparams.fw_loss_cutoff,
                n_fft=hparams.n_fft)

            # Magnitude term (channel 0).
            mag_err = x[:, :, :, 0] - xhat[:, :, :, 0]
            mag_loss = tf.reduce_mean(freq_mask * mag_err**2)

            # Magnitude-only models stop here; otherwise the total is
            # overwritten below once the phase term is known.
            total_loss = mag_loss
            if not hparams.mag_only:
                # Phase term (channel 1).
                phase_weight = freq_mask * phase_mask
                phase_err = x[:, :, :, 1] - xhat[:, :, :, 1]
                if hparams.dphase:
                    phase_loss = tf.reduce_mean(phase_weight * phase_err**2)
                else:
                    # Von Mises distribution ("circular normal").
                    # The constant 1 is added to keep the loss positive
                    # (same probability), range [0, 2].
                    # NOTE(review): multiplying by pi presumably assumes the
                    # phase is stored normalised to [-1, 1] -- confirm.
                    phase_loss = 1 - tf.reduce_mean(
                        phase_weight * tf.cos(phase_err * np.pi))
                total_loss = mag_loss + hparams.phase_loss_coeff * phase_loss
                tf.summary.scalar("Loss/Mag", mag_loss)
                tf.summary.scalar("Loss/Phase", phase_loss)
        tf.summary.scalar("Loss/Total", total_loss)
    return total_loss
# Web-page residue (not part of the code): "Comment list" / "Article table of contents".