def _run(self):
    """Train the model epoch by epoch, validating after every epoch.

    Training starts at ``config.epoch`` and continues while
    ``config.epoch <= config.max_max_epoch``.

    For each epoch:
      * set the training model's learning rate (with decay, see below);
      * run one training epoch on ``data.train`` through ``run_epoch``,
        which also receives ``self.io`` and the logging/saving rates from
        ``self.params``;
      * run one validation epoch on ``data.valid`` with the validation
        model;
      * reset ``config.step`` to 0 and advance ``config.epoch``;
      * persist the config (``config.save()``);
      * write a checkpoint named ``ep_<config.epoch>.ckpt`` via
        ``self.io.save_checkpoint``. The epoch number in the name is the
        already-advanced value.
    """
    m, mvalid = self.train_model, self.validation_model
    config = self.config
    data = self.data
    params = self.params

    # ``tf.initialize_all_variables`` is deprecated; prefer its replacement
    # when this TensorFlow build provides it, otherwise fall back.
    init_fn = getattr(tf, "global_variables_initializer", None) or tf.initialize_all_variables
    init_op = init_fn()

    with tf.Session() as session:
        session.run(init_op)
        # NOTE(review): variables are freshly initialised here even when
        # config.epoch > 1 (i.e. a resumed run); nothing in this method
        # restores a checkpoint -- confirm restoration happens elsewhere.
        print("Starting training from epoch %d using %s loss" % (config.epoch, m.loss_fct))

        while config.epoch <= config.max_max_epoch:
            i = config.epoch

            # Factor is 1 for epochs up to ``max_epoch``; after that the base
            # learning rate is multiplied by ``lr_decay`` once per extra epoch.
            lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
            m.assign_lr(session, config.learning_rate * lr_decay)
            # Read the value back from the graph so the log shows what was
            # actually assigned.
            print("\nEpoch: %d Learning rate: %.3f" % (i, session.run(m.lr)))

            train_perplexity = run_epoch(session, m,
                                         data.train,
                                         eval_op=m.train_op,
                                         verbose=True,
                                         opIO=self.io,
                                         log_rate=params.log_rate,
                                         save_rate=params.save_rate)
            print("Epoch: %d Train Perplexity: %.3f" % (i, train_perplexity))

            print("Validation using %s loss" % mvalid.loss_fct)
            valid_perplexity = run_epoch(session, mvalid, data.valid)
            print("Epoch: %d Valid Perplexity: %.3f" % (i, valid_perplexity))

            # Epoch finished: reset the intra-epoch step counter and move on.
            # The config is saved after the bump, so a persisted config points
            # at the next epoch to run.
            config.step = 0
            config.epoch += 1
            config.save()
            self.io.save_checkpoint(session, "ep_%d.ckpt" % config.epoch)
# (Web-page residue from the source article, kept as comments:)
# Comment list
# Table of contents