def main(_):
    """Build the data, model and training ops, then hand off to training.

    Steps, in order:
      1. Fetch the data with ``util.get_data`` using the batch size,
         sequence length, dataset and embedding size flags.
      2. Build the forward model from that data with ``get_forward``.
      3. Create the ``global_step`` variable, a ``Saver`` that keeps a single
         checkpoint, and the loss/train ops from ``minimise_xent``.
      4. Call ``do_training`` with all of the above (presumably where
         variables are initialised or reloaded and the training loop runs --
         confirm against ``do_training``).

    Args:
        _: ignored.
    """
    _start_msg('getting data')
    dataset = util.get_data(
        FLAGS.batch_size,
        FLAGS.sequence_length,
        FLAGS.dataset,
        FLAGS.embedding_size)
    _end_msg('got data')

    _start_msg('getting forward model')
    forward_model = get_forward(dataset)
    _end_msg('got forward model')

    _start_msg('getting train ops')
    # TODO(pfcm): be more flexible with this
    step_counter = tf.Variable(0, name='global_step', trainable=False)
    # NOTE(review): tf.all_variables() is evaluated here, before
    # minimise_xent runs, so any variables minimise_xent creates would not
    # be covered by this Saver -- confirm that is intended.
    checkpoint_saver = tf.train.Saver(tf.all_variables(), max_to_keep=1)
    logits = forward_model['inference']['logits']
    targets = dataset['placeholders']['targets']
    loss_op, train_op = minimise_xent(
        logits, targets, global_step=step_counter)
    _end_msg('got train ops')

    do_training(
        dataset, forward_model, loss_op, train_op, saver=checkpoint_saver)
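# (web page residue, not part of the program) Comment list
# (web page residue, not part of the program) Table of contents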