train.py source code

python

Project: generating_sequences    Author: PFCM

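import tensorflow as tf

# Not shown in this excerpt; defined or imported elsewhere in the
# generating_sequences project: util, FLAGS, get_forward, minimise_xent,
# do_training, _start_msg and _end_msg.


def main(_):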
    """Does all the things that we need to do:
        - gets data
        - sets up the graph for inference and sampling
        - gets training ops etc.
        - initialise or reload the variables.
        - train until it's time to go.
    """
    _start_msg('getting data')
    data = util.get_data(FLAGS.batch_size, FLAGS.sequence_length,
                         FLAGS.dataset, FLAGS.embedding_size)
    _end_msg('got data')

    _start_msg('getting forward model')
    rnn_model = get_forward(data)
    _end_msg('got forward model')

    _start_msg('getting train ops')
    # TODO(pfcm): be more flexible with this
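    global_step = tf.Variable(0, name='global_step', trainable=False)
    # tf.all_variables() was renamed tf.global_variables() in TensorFlow 0.12;
    # the old name is deprecated. max_to_keep=1 keeps only the latest checkpoint.
    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=1)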
    loss_op, train_op = minimise_xent(
        rnn_model['inference']['logits'], data['placeholders']['targets'],
        global_step=global_step)
    _end_msg('got train ops')
    do_training(data, rnn_model, loss_op, train_op, saver=saver)
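minimise_xent is not shown in this excerpt. Its name and arguments suggest that it builds a cross-entropy loss between the model's logits and the target symbols, and that it returns that loss together with an optimiser step that advances global_step. The sketch below is written in plain Python with no TensorFlow, and `sequence_xent` is a hypothetical name. It covers only the loss part: the mean softmax cross-entropy over every batch and time position, computed with the log-sum-exp trick so that large scores do not overflow. The [batch][time][vocab] logits layout is also an assumption.

```python
import math
from typing import Sequence


def sequence_xent(logits: Sequence[Sequence[Sequence[float]]],
                  targets: Sequence[Sequence[int]]) -> float:
    """Mean softmax cross-entropy over every (batch, time) position.

    Hypothetical helper (not part of the generating_sequences project).
    logits:  nested sequence shaped [batch][time][vocab] of unnormalised scores
    targets: nested sequence shaped [batch][time] of integer symbol indices
    """
    if len(logits) != len(targets):
        raise ValueError("batch sizes of logits and targets differ")
    total, count = 0.0, 0
    for seq_logits, seq_targets in zip(logits, targets):
        if len(seq_logits) != len(seq_targets):
            raise ValueError("sequence lengths of logits and targets differ")
        for scores, target in zip(seq_logits, seq_targets):
            # log Z = log(sum(exp(s))), shifted by the max score so exp() cannot overflow
            m = max(scores)
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            # -log softmax(scores)[target] = log Z - scores[target]
            total += log_z - scores[target]
            count += 1
    if count == 0:
        raise ValueError("no time steps to average over")
    # Average over all positions, matching a reduce_mean over batch and time
    return total / count
```

As a sanity check, uniform logits over a 4-symbol vocabulary give a loss of log 4 ≈ 1.386, which is the loss of a model that guesses blindly. In TensorFlow 1.x the same quantity is usually written as `tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits))`. Passing `global_step` to an optimiser's `minimize()` makes TensorFlow increment it once per training step, which is presumably why `main` hands it to `minimise_xent`.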