train.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: DL2W 作者: gauravmm 项目源码 文件源码
def train(tfrecord_file, train_dir, batch_size, num_epochs,
          learning_rate=0.001, log_every=100):
    """Train the model from a TFRecord file under a ``tf.train.Supervisor``.

    Builds the input pipeline, the model loss and an Adam training op. It then
    runs training steps until the supervisor reports that it should stop. The
    supervisor is configured to save summaries every 60 seconds and the model
    every 600 seconds, using ``train_dir`` as its log directory.

    Args:
        tfrecord_file: Path of the TFRecord file passed to ``data_loader.inputs``.
        train_dir: Log directory handed to the supervisor.
        batch_size: Batch size for the input pipeline. It also sizes the input
            queue (capacity ``4 * batch_size``, min_after_dequeue
            ``2 * batch_size``).
        num_epochs: Forwarded to ``data_loader.inputs``. Presumably this bounds
            the number of passes over the data; confirm in data_loader.
        learning_rate: Adam learning rate. Defaults to 0.001, the value that
            was previously hard-coded.
        log_every: Print the loss whenever the fetched global step is a
            multiple of this value. It must be a positive integer. Defaults to
            100, the value that was previously hard-coded.
    """
    # The first element returned by data_loader.inputs is unused here.
    _, vectors, labels = data_loader.inputs(
        [tfrecord_file], batch_size=batch_size,
        num_threads=16, capacity=batch_size * 4,
        min_after_dequeue=batch_size * 2,
        num_epochs=num_epochs, is_training=True)

    loss = model.loss(vectors, labels)

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Make the training op depend on everything in UPDATE_OPS (batch-norm
    # updates) so those ops run on every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)

    # The supervisor manages the session, summary writing and checkpointing.
    sv = tf.train.Supervisor(logdir=train_dir, global_step=global_step,
                             save_summaries_secs=60, save_model_secs=600)

    with sv.managed_session() as sess:
        while not sv.should_stop():
            # NOTE(review): global_step is fetched in the same run as the op
            # that increments it. The reported value may therefore be the
            # pre- or the post-increment step.
            _, loss_out, step_out = sess.run([train_op, loss, global_step])

            if step_out % log_every == 0:
                print('Step {}: Loss {}'.format(step_out, loss_out))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号