import logging
import os
import time
import cPickle

import tensorflow as tf

# Batcher and Generator are defined elsewhere in this project.


def train_generator(args, load_recent=True):
    '''Train the generator via the classical approach.'''
    logging.debug('Batcher...')
    batcher = Batcher(args.data_dir, args.batch_size, args.seq_length)
    logging.debug('Vocabulary...')
    # Persist the run configuration and vocabulary so the model can be reloaded later.
    with open(os.path.join(args.save_dir_gen, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir_gen, 'real_beer_vocab.pkl'), 'wb') as f:
        cPickle.dump((batcher.chars, batcher.vocab), f)
    logging.debug('Creating generator...')
    generator = Generator(args, is_training=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        # Optionally resume from the most recent checkpoint in the save directory.
        if load_recent:
            ckpt = tf.train.get_checkpoint_state(args.save_dir_gen)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
        for epoch in xrange(args.num_epochs):
            # Anneal the learning rate: lr = learning_rate * decay_rate ** epoch.
            new_lr = args.learning_rate * (args.decay_rate ** epoch)
            sess.run(tf.assign(generator.lr, new_lr))
            batcher.reset_batch_pointer()
            state = generator.initial_state.eval()
            for batch in xrange(batcher.num_batches):
                start = time.time()
                x, y = batcher.next_batch()
                feed = {generator.input_data: x, generator.targets: y, generator.initial_state: state}
                # Fetch final_state as well, so the RNN state carries over to the next batch.
                train_loss, state, _ = sess.run([generator.cost, generator.final_state, generator.train_op], feed)
                end = time.time()
                print '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}' \
                    .format(epoch * batcher.num_batches + batch,
                            args.num_epochs * batcher.num_batches,
                            epoch, train_loss, end - start)
                if (epoch * batcher.num_batches + batch) % args.save_every == 0:
                    checkpoint_path = os.path.join(args.save_dir_gen, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=epoch * batcher.num_batches + batch)
                    print 'Generator model saved to {}'.format(checkpoint_path)
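For reference, here is a minimal sketch of how this function might be driven from a small argparse front end. The argument names mirror the attributes that train_generator reads from args; the default values, the main wrapper, and any extra hyperparameters the project's Generator constructor may expect are illustrative assumptions, not part of the original code.

import argparse

def main():
    parser = argparse.ArgumentParser()
    # Names match the attributes accessed in train_generator(); defaults are placeholders.
    parser.add_argument('--data_dir', type=str, default='data')
    parser.add_argument('--save_dir_gen', type=str, default='save_gen')
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--seq_length', type=int, default=50)
    parser.add_argument('--num_epochs', type=int, default=25)
    parser.add_argument('--learning_rate', type=float, default=0.002)
    parser.add_argument('--decay_rate', type=float, default=0.97)
    parser.add_argument('--save_every', type=int, default=500)
    # Note: the Generator model itself likely needs additional hyperparameters
    # (e.g. network size) that are not shown in this excerpt.
    args = parser.parse_args()
    train_generator(args, load_recent=True)

if __name__ == '__main__':
    main()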