train_models.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目: CharacterGAN 作者: liamb315 项目源码 文件源码
def train_generator(args, load_recent=True):
    '''Train the generator on its own, using its ``cost`` and ``train_op``.

    Writes ``config.pkl`` (the ``args`` object) and ``real_beer_vocab.pkl``
    (``(batcher.chars, batcher.vocab)``) into ``args.save_dir_gen``, then runs
    ``args.num_epochs`` epochs over the batches produced by ``Batcher``.
    The learning rate is annealed once per epoch as
    ``args.learning_rate * args.decay_rate ** epoch``. A checkpoint
    (``model.ckpt`` with the global step appended) is written every
    ``args.save_every`` steps, and once more after the final step if that
    step was not already saved.

    Args:
        args: configuration namespace. Reads data_dir, batch_size, seq_length,
            save_dir_gen, num_epochs, learning_rate, decay_rate and
            save_every, and is also passed whole to ``Generator``.
        load_recent: if True and ``args.save_dir_gen`` holds a checkpoint,
            restore its weights before training. NOTE(review): the
            epoch/step counters and the learning-rate schedule still restart
            from 0 after a restore.
    '''
    logging.debug('Batcher...')
    batcher = Batcher(args.data_dir, args.batch_size, args.seq_length)

    logging.debug('Vocabulary...')
    # Pickles are binary data: open in 'wb' so binary pickle protocols and
    # text-mode newline translation (e.g. on Windows) cannot corrupt them.
    with open(os.path.join(args.save_dir_gen, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir_gen, 'real_beer_vocab.pkl'), 'wb') as f:
        cPickle.dump((batcher.chars, batcher.vocab), f)

    logging.debug('Creating generator...')
    generator = Generator(args, is_training=True)

    # Build the learning-rate update op exactly once. Calling tf.assign inside
    # the epoch loop would add a fresh op to the graph on every epoch.
    lr_value = tf.placeholder(generator.lr.dtype.base_dtype, shape=[])
    lr_update = tf.assign(generator.lr, lr_value)

    checkpoint_path = os.path.join(args.save_dir_gen, 'model.ckpt')
    total_steps = args.num_epochs * batcher.num_batches

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())

        if load_recent:
            ckpt = tf.train.get_checkpoint_state(args.save_dir_gen)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

        step = -1               # global step of the most recent batch
        last_saved_step = None  # global step of the most recent checkpoint
        for epoch in xrange(args.num_epochs):
            # Anneal learning rate geometrically per epoch.
            sess.run(lr_update, {lr_value: args.learning_rate * (args.decay_rate ** epoch)})
            batcher.reset_batch_pointer()
            state = generator.initial_state.eval()

            for batch in xrange(batcher.num_batches):
                step = epoch * batcher.num_batches + batch
                start = time.time()
                x, y = batcher.next_batch()
                # NOTE(review): `state` is never refreshed from
                # generator.final_state, so every batch starts from the
                # epoch's initial state. The stateful variant would be
                # `train_loss, state, _ = sess.run([generator.cost,
                # generator.final_state, generator.train_op], feed)`;
                # confirm which behavior is intended.
                feed = {generator.input_data: x, generator.targets: y, generator.initial_state: state}
                train_loss, _ = sess.run([generator.cost, generator.train_op], feed)
                end = time.time()

                print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'
                      .format(step, total_steps, epoch, train_loss, end - start))

                if step % args.save_every == 0:
                    saver.save(sess, checkpoint_path, global_step=step)
                    last_saved_step = step
                    print('Generator model saved to {}'.format(checkpoint_path))

        # Persist the final weights when the last step did not land on a
        # save_every boundary; otherwise the trailing steps would be lost.
        if step >= 0 and step != last_saved_step:
            saver.save(sess, checkpoint_path, global_step=step)
            print('Generator model saved to {}'.format(checkpoint_path))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号