ae.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:magenta 作者: tensorflow 项目源码 文件源码
def compute_mse_loss(x, xhat, hparams):
  """Builds the reconstruction loss between `x` and its reconstruction `xhat`.

  When `hparams.raw_audio` is set, the loss is the plain mean squared error
  over all elements. Otherwise, index 0 of axis 3 is treated as the magnitude
  channel and index 1 of axis 3 as the phase channel:

  * Magnitude term: mean of the frequency-weighted squared error.
  * Phase term (skipped when `hparams.mag_only`): either the weighted squared
    error (`hparams.dphase`) or `1 - mean(w * cos(pi * diff))`. The weight `w`
    is the frequency weight, multiplied by the true magnitude channel when
    `hparams.cost_phase_mask` is set.

  The total is `mag + hparams.phase_loss_coeff * phase`, or just the
  magnitude term when `hparams.mag_only` is set. Scalar summaries
  "Loss/Mag" and "Loss/Phase" are written only when a phase term is
  computed; "Loss/Total" is always written.

  Args:
    x: Input data tensor. In the spectral case it is assumed to be at least
      4-D with channels on axis 3 (TODO confirm against callers).
    xhat: Reconstruction tensor, same layout as `x`.
    hparams: Hyperparameters. Reads raw_audio, cost_phase_mask, fw_loss_coeff,
      fw_loss_cutoff, n_fft, mag_only, dphase and phase_loss_coeff.

  Returns:
    total_loss: Scalar loss tensor.
  """
  with tf.name_scope("Losses"):
    if not hparams.raw_audio:
      # The true magnitude optionally masks the phase error; 1.0 disables it.
      mag_weight = x[:, :, :, 0] if hparams.cost_phase_mask else 1.0
      freq_weight = utils.frequency_weighted_cost_mask(
          hparams.fw_loss_coeff,
          hz_flat=hparams.fw_loss_cutoff,
          n_fft=hparams.n_fft)

      # Magnitude channel: frequency-weighted squared error.
      mag_diff = x[:, :, :, 0] - xhat[:, :, :, 0]
      mag_loss = tf.reduce_mean(freq_weight * mag_diff**2)

      if hparams.mag_only:
        total_loss = mag_loss
      else:
        # Phase channel, weighted by frequency and (optionally) magnitude.
        phase_weight = freq_weight * mag_weight
        phase_diff = x[:, :, :, 1] - xhat[:, :, :, 1]
        if hparams.dphase:
          phase_loss = tf.reduce_mean(phase_weight * phase_diff**2)
        else:
          # Cosine ("circular normal" / von Mises style) penalty. The `1 -`
          # offset keeps the term non-negative, provided the weights do not
          # exceed 1 (NOTE(review): confirm the weight range).
          phase_loss = 1 - tf.reduce_mean(
              phase_weight * tf.cos(phase_diff * np.pi))
        total_loss = mag_loss + hparams.phase_loss_coeff * phase_loss
        tf.summary.scalar("Loss/Mag", mag_loss)
        tf.summary.scalar("Loss/Phase", phase_loss)
    else:
      total_loss = tf.reduce_mean((x - xhat)**2)
    tf.summary.scalar("Loss/Total", total_loss)
  return total_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号