train.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tensorflow-grid-lstm 作者: philipperemy 项目源码 文件源码
def _evaluate_validation(sess, model, val_batches):
    """Return the mean per-batch loss of ``model`` over ``val_batches``.

    Only ``model.cost`` and ``model.final_state`` are fetched, so this never
    updates the model's weights. The recurrent state is carried from one
    validation batch to the next, starting from ``model.initial_state``.

    Args:
        sess: the active TensorFlow session.
        model: model exposing ``input_data``, ``targets``, ``initial_state``,
            ``final_state`` and ``cost`` tensors.
        val_batches: an already-materialized sequence of ``(x, y)`` pairs.

    Returns:
        The average validation loss, or ``0`` if there are no batches.
    """
    if not val_batches:
        return 0
    state = sess.run(model.initial_state)
    total_loss = 0.0
    for x_val, y_val in val_batches:
        feed = {model.input_data: x_val, model.targets: y_val, model.initial_state: state}
        # Deliberately NOT fetching model.train_op: doing so would run an
        # optimizer step on validation data.
        val_loss, state = sess.run([model.cost, model.final_state], feed)
        total_loss += val_loss
    return total_loss / len(val_batches)


def train(args):
    """Train a ``Model`` on the data in ``args.data_dir``, checkpointing as it goes.

    Reads from ``args``: ``data_dir``, ``batch_size``, ``seq_length``,
    ``save_dir``, ``num_epochs``, ``learning_rate``, ``decay_rate`` and
    ``save_every``. Sets ``args.vocab_size`` from the data loader before
    pickling ``args``.

    Files written under ``args.save_dir``:
      * ``config.pkl``      -- the pickled ``args``.
      * ``chars_vocab.pkl`` -- pickled ``(data_loader.chars, data_loader.vocab)``.
      * ``model.ckpt`` checkpoints (suffixed with the global step) every
        ``args.save_every`` global batches, after a validation pass.
      * ``log.csv``         -- per-iteration epoch / train loss / val loss
        (val loss is empty on non-evaluation iterations), rewritten at the
        end of every epoch.
    """
    data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
    args.vocab_size = data_loader.vocab_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        pickle.dump((data_loader.chars, data_loader.vocab), f)

    # Materialize once: the same batches are reused at every evaluation and
    # their count is needed for averaging.
    val_batches = list(data_loader.val_batches)

    model = Model(args)

    # Build the learning-rate update op once. Calling tf.assign inside the
    # epoch loop would add a fresh op to the graph on every epoch.
    lr_value = tf.placeholder(model.lr.dtype.base_dtype, shape=model.lr.get_shape())
    assign_lr = tf.assign(model.lr, lr_value)

    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
    log_path = os.path.join(args.save_dir, 'log.csv')

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        train_loss_iterations = {'iteration': [], 'epoch': [], 'train_loss': [], 'val_loss': []}

        for e in range(args.num_epochs):
            # Exponential decay: lr = learning_rate * decay_rate ** epoch.
            sess.run(assign_lr, {lr_value: args.learning_rate * (args.decay_rate ** e)})
            data_loader.reset_batch_pointer()
            # The RNN state is carried across consecutive batches within an epoch.
            state = sess.run(model.initial_state)
            for b in range(data_loader.num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y, model.initial_state: state}
                train_loss, state, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                batch_idx = e * data_loader.num_batches + b
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(
                    batch_idx,
                    args.num_epochs * data_loader.num_batches,
                    e, train_loss, end - start))
                train_loss_iterations['iteration'].append(batch_idx)
                train_loss_iterations['epoch'].append(e)
                train_loss_iterations['train_loss'].append(train_loss)

                if batch_idx % args.save_every == 0:
                    avg_val_loss = _evaluate_validation(sess, model, val_batches)
                    print('val_loss: {:.3f}'.format(avg_val_loss))
                    train_loss_iterations['val_loss'].append(avg_val_loss)

                    saver.save(sess, checkpoint_path, global_step=batch_idx)
                    print("model saved to {}".format(checkpoint_path))
                else:
                    # Keep every column the same length for the DataFrame below.
                    train_loss_iterations['val_loss'].append(None)

            pd.DataFrame(data=train_loss_iterations,
                         columns=list(train_loss_iterations)).to_csv(log_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号