def __call__(self, trainer):
    """Log per-epoch statistics and generate a text sample from the model.

    Logs the epoch number, the wall-clock seconds elapsed since
    ``self.time_epoch`` was last set, and the current ``main/loss``
    observation. It then generates text from a seed derived from
    ``self.text``. The predictor's RNN state is snapshotted beforehand and
    always restored afterwards, even if generation raises.

    Args:
        trainer: The trainer invoking this extension. Reads
            ``trainer.updater.epoch``, ``trainer.observation["main/loss"]``
            and ``trainer.updater.get_optimizer("main").target``.
    """
    duration_epoch = time.time() - self.time_epoch
    epoch = trainer.updater.epoch
    # NOTE(review): assumes the observation is a Variable-like object
    # exposing ``.data`` (e.g. a reported Chainer Variable) -- confirm.
    loss = trainer.observation["main/loss"].data
    logger.info("epoch: %s, duration: %ds, loss: %.6g.",
                epoch, duration_epoch, loss)

    # Snapshot the RNN state. Sampling presumably advances it, and training
    # must resume from the state it had before sampling.
    model = trainer.updater.get_optimizer("main").target
    state = model.predictor.get_state()
    try:
        seed = generate_seed(self.text)
        generate_text(model, seed)
    finally:
        # Restore unconditionally so a failure during generation cannot
        # leave the predictor in its post-sampling state.
        model.predictor.set_state(state)

    # Restart the epoch timer after sampling so that generation time is not
    # counted toward the next epoch's reported duration.
    self.time_epoch = time.time()
# [scraped page residue] Comment list
# [scraped page residue] Table of contents