def run_epoch(session, config, graph, iterator, ops=None,
              summary_writer=None, summary_prefix=None, saver=None):
    """Runs the model for one epoch over the batches yielded by `iterator`.

    Args:
        session: TensorFlow session used for all run()/eval() calls.
        config: dict-like config; this function reads 'monitoring_frequency',
            'saving_frequency', 'batch_size' and 'next_worker_delay'.
        graph: dict of graph endpoints; this function reads 'inputs',
            'input_lengths', 'observed', 'total_neg_loglikelihood',
            'total_correct', 'step_number' and 'learning_rate'.
        iterator: iterable yielding (inputs, lengths) batches.
        ops: optional extra ops run alongside every fetch (e.g. the train op);
            their results are discarded.
        summary_writer: optional summary writer handed to Monitor.
        summary_prefix: tag prefix handed to Monitor.
        saver: optional tf.train.Saver; checkpoints go to
            FLAGS.train_path/model.

    Returns:
        The result of the last Monitor.monitor() call, or None if the
        iterator yielded no batches (the original raised NameError there).
    """
    import time  # local import: only needed for the start-up stagger below
    if not ops:
        ops = []
    # Shortcuts, ugly but still increase readability.  Assigned *before*
    # the closures below so the dependency is explicit (the original code
    # relied on Python's late binding of closed-over names).
    c = config
    g = graph

    def should_monitor(step):
        # Never fires on step 0; fires every 'monitoring_frequency' steps.
        return step and c['monitoring_frequency'] and (step + 1) % c['monitoring_frequency'] == 0

    def should_save(step):
        # Never fires on step 0; fires every 'saving_frequency' steps.
        return step and c['saving_frequency'] and (step + 1) % c['saving_frequency'] == 0

    m = Monitor(summary_writer, summary_prefix)
    # Stagger worker start-up: wait until the global step reaches this
    # worker's threshold.  The original busy-waited with `pass`, spinning
    # at 100% CPU and issuing an eval() per loop; a short sleep keeps the
    # same delay semantics while freeing the CPU.
    while g['step_number'].eval() < FLAGS.task * c['next_worker_delay']:
        time.sleep(0.1)
    # Statistics
    result = None            # last monitor() result; None if no batches ran
    monitored_last = False   # did the final loop iteration already monitor?
    for step, (inputs, lengths) in enumerate(iterator):
        # Define what we feed
        feed_dict = {g['inputs']: inputs,
                     g['input_lengths']: lengths}
        # Define what we fetch: the observed stats plus the two loss
        # totals; '_ops' carries side-effect ops whose values we ignore.
        fetch = dict(g['observed'])
        fetch['total_neg_loglikelihood'] = g['total_neg_loglikelihood']
        fetch['total_correct'] = g['total_correct']
        fetch['_ops'] = ops
        # RUN!!!
        r = session.run(fetch, feed_dict)
        # Update the monitor accumulators
        m.total_neg_loglikelihood += r['total_neg_loglikelihood']
        m.total_correct += r['total_correct']
        # We do not predict the first words, that's why
        # batch_size has to be subtracted from the total
        m.steps += 1
        m.words += sum(lengths) - c['batch_size']
        m.sentences += c['batch_size']
        m.words_including_padding += c['batch_size'] * len(inputs[0])
        m.step_number = g['step_number'].eval()
        m.learning_rate = float(g['learning_rate'].eval())
        for key in g['observed']:
            m.observed[key] += r[key]
        monitored_last = should_monitor(step)
        if monitored_last:
            tf.logging.info('monitor')
            result = m.monitor()
        if saver and should_save(step):
            # tf.logging for consistency with the monitoring message
            # (was a bare print("saved")).
            tf.logging.info('saved')
            saver.save(session, os.path.join(FLAGS.train_path, 'model'))
    # Final monitor/save unless the last iteration already monitored.
    # Guard on m.steps so an empty epoch neither monitors empty
    # accumulators nor (as before) hits a NameError on `step`.
    if m.steps and not monitored_last:
        result = m.monitor()
        if saver:
            saver.save(session, os.path.join(FLAGS.train_path, 'model'))
    return result