lm.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:wip-lambada-lm 作者: brain-research 项目源码 文件源码
def run_epoch(session, config, graph, iterator, ops=None,
              summary_writer=None, summary_prefix=None, saver=None):
  """Runs the model on every batch produced by `iterator`.

  Args:
    session: session used for all `run`/`eval` calls and for checkpointing.
    config: mapping read for 'monitoring_frequency', 'saving_frequency',
      'next_worker_delay' and 'batch_size'.
    graph: mapping read for 'inputs', 'input_lengths', 'observed',
      'total_neg_loglikelihood', 'total_correct', 'step_number' and
      'learning_rate'.
    iterator: iterable yielding `(inputs, lengths)` pairs, one per batch.
    ops: optional extra ops fetched with every step (e.g. a train op);
      their results are ignored.
    summary_writer: forwarded to `Monitor`.
    summary_prefix: forwarded to `Monitor`.
    saver: optional saver; when given, the model is checkpointed to
      `FLAGS.train_path/model` periodically and once more at the end.

  Returns:
    Whatever the last `Monitor.monitor()` call returned, or None when
    `iterator` produced no batches.
  """
  import time  # local import: only needed for the start-up wait below

  if not ops:
    ops = []

  # Shortcuts, ugly but still increase readability
  c = config
  g = graph
  m = Monitor(summary_writer, summary_prefix)

  # Both predicates are never true for step 0 and are disabled when the
  # corresponding frequency is falsy (0 / None).
  def should_monitor(step):
    return bool(step and c['monitoring_frequency']
                and (step + 1) % c['monitoring_frequency'] == 0)

  def should_save(step):
    return bool(step and c['saving_frequency']
                and (step + 1) % c['saving_frequency'] == 0)

  # Staggered start: wait until the step counter has reached this task's
  # threshold. Sleep between polls instead of spinning on the session.
  start_threshold = FLAGS.task * c['next_worker_delay']
  while g['step_number'].eval(session=session) < start_threshold:
    time.sleep(1.0)

  result = None
  batches_seen = False
  monitored_last_step = False

  for step, (inputs, lengths) in enumerate(iterator):
    batches_seen = True

    # Define what we feed
    feed_dict = {g['inputs']: inputs,
                 g['input_lengths']: lengths}

    # Define what we fetch
    fetch = dict(g['observed'])
    fetch['total_neg_loglikelihood'] = g['total_neg_loglikelihood']
    fetch['total_correct'] = g['total_correct']
    fetch['_ops'] = ops

    r = session.run(fetch, feed_dict)

    # Update the monitor accumulators
    m.total_neg_loglikelihood += r['total_neg_loglikelihood']
    m.total_correct += r['total_correct']
    m.steps += 1
    # We do not predict the first words, that's why
    # batch_size has to be subtracted from the total
    m.words += sum(lengths) - c['batch_size']
    m.sentences += c['batch_size']
    m.words_including_padding += c['batch_size'] * len(inputs[0])
    # Evaluate against the explicit session so the function does not depend
    # on `session` having been installed as the default session.
    m.step_number = g['step_number'].eval(session=session)
    m.learning_rate = float(g['learning_rate'].eval(session=session))
    for key in g['observed']:
      m.observed[key] += r[key]

    monitored_last_step = should_monitor(step)
    if monitored_last_step:
      tf.logging.info('monitor')
      result = m.monitor()
    if saver and should_save(step):
      print("saved")
      saver.save(session, os.path.join(FLAGS.train_path, 'model'))

  # Report whatever accumulated since the last periodic report. With an
  # empty iterator there is nothing to report (and `step` was never bound),
  # so skip the report and return None.
  if batches_seen and not monitored_last_step:
    result = m.monitor()
  if saver:
    saver.save(session, os.path.join(FLAGS.train_path, 'model'))
  return result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号