train.py source code

python

Project: nmt   Author: tensorflow
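import math

# nmt's helper module (provides safe_exp and print_out).
# `_get_best_results` is a helper defined elsewhere in train.py.
from .utils import misc_utils as utils


def check_stats(stats, global_step, steps_per_stats, hparams, log_f):
  """Print statistics and also check for overflow."""
  # Print statistics accumulated over the last `steps_per_stats` steps.
  avg_step_time = stats["step_time"] / steps_per_stats
  avg_grad_norm = stats["grad_norm"] / steps_per_stats
  # Perplexity: exp of the mean loss per predicted token (inf on overflow).
  train_ppl = utils.safe_exp(
      stats["loss"] / stats["predict_count"])
  # Throughput in thousands of words per second (printed as "wps %.2fK").
  speed = stats["total_count"] / (1000 * stats["step_time"])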
  utils.print_out(
      "  global step %d lr %g "
      "step-time %.2fs wps %.2fK ppl %.2f gN %.2f %s" %
      (global_step, stats["learning_rate"],
       avg_step_time, speed, train_ppl, avg_grad_norm,
       _get_best_results(hparams)),
      log_f)

  # Check for overflow
  is_overflow = False
  if math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > 1e20:
    utils.print_out("  step %d overflow, stop early" % global_step, log_f)
    is_overflow = True

  return is_overflow
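The report line combines a few simple ratios. A minimal standalone sketch of that arithmetic follows, assuming a plain dict with the same keys `check_stats` reads. `summarize_stats` is a hypothetical name, and the local `safe_exp` is an assumed stand-in for `utils.safe_exp` (an exp that returns infinity instead of raising `OverflowError`); the counter values are made up.

```python
import math


def safe_exp(value):
    """exp(value), returning inf instead of raising OverflowError.

    Assumed stand-in for nmt's utils.safe_exp.
    """
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


def summarize_stats(stats, steps_per_stats):
    """Hypothetical helper: reduce counters summed over `steps_per_stats` steps.

    Assumed keys (same as check_stats): step_time in seconds, grad_norm,
    loss summed over predicted tokens, predict_count, total_count (words).
    """
    avg_step_time = stats["step_time"] / steps_per_stats
    avg_grad_norm = stats["grad_norm"] / steps_per_stats
    # Perplexity = exp(mean loss per predicted token).
    train_ppl = safe_exp(stats["loss"] / stats["predict_count"])
    # Words per second divided by 1000, i.e. the "wps %.2fK" figure.
    speed = stats["total_count"] / (1000 * stats["step_time"])
    return avg_step_time, avg_grad_norm, train_ppl, speed


if __name__ == "__main__":
    # Made-up counters for 100 steps, only to exercise the arithmetic.
    demo = {"step_time": 50.0, "grad_norm": 500.0, "loss": 2000.0,
            "predict_count": 1000, "total_count": 200000}
    t, gn, ppl, wps = summarize_stats(demo, steps_per_stats=100)
    print("step-time %.2fs wps %.2fK ppl %.2f gN %.2f" % (t, wps, ppl, gn))
    # prints: step-time 0.50s wps 4.00K ppl 7.39 gN 5.00
```

Step time and gradient norm are averaged over the reporting window, while perplexity is computed from the pooled loss and token count, which is not the same as averaging per-step perplexities.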
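The overflow test relies on one detail worth spelling out: every comparison with NaN is false, so `train_ppl > 1e20` alone would never flag a NaN perplexity. A small sketch of the same predicate follows, plus a hypothetical early-stopping loop over a list of reported perplexities; `ppl_overflowed` and `first_overflow` are illustrative names, not nmt APIs.

```python
import math


def ppl_overflowed(train_ppl, limit=1e20):
    """Same test as check_stats: NaN, infinite, or above `limit`."""
    # NaN must be checked explicitly: nan > limit evaluates to False.
    return math.isnan(train_ppl) or math.isinf(train_ppl) or train_ppl > limit


def first_overflow(reported_ppls):
    """Hypothetical loop: index of the first report that should stop training."""
    for i, ppl in enumerate(reported_ppls):
        if ppl_overflowed(ppl):
            return i  # stop early, as check_stats signals via its return value
    return None  # no report overflowed


if __name__ == "__main__":
    print(first_overflow([35.2, 18.4, float("nan"), 9.1]))  # prints: 2
```

Because `safe_exp` turns an overflowed exponent into `inf`, the `isinf` branch catches that case explicitly (it would also exceed `1e20`), while the threshold stops runs whose perplexity is still finite but far beyond any useful value.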