def __init__(self, folder, chain, train, test, batchsize=500, resume=True, gpu=0, nepoch=1, reports=None):
    """Build a Chainer trainer for ``chain`` and optionally resume it from disk.

    The trainer is stored on ``self.trainer``; this constructor does not
    start training.

    Args:
        folder: Output directory for the graph dump, logs and snapshots.
            Created if it does not exist.
        chain: Model link to train. It must expose an ``optimizer``
            attribute (handed to the updater) and support ``to_gpu()``,
            ``copy()`` and a writable ``test`` attribute.
        train: Training dataset, iterated in shuffled, repeating minibatches.
        test: Evaluation dataset, iterated once per evaluation, unshuffled.
        batchsize: Minibatch size used for both iterators.
        resume: If true, load the newest trainer snapshot found in
            ``folder`` (see below) into the trainer.
        gpu: CUDA device id; a negative value keeps everything on the CPU.
        nepoch: Number of epochs after which the trainer stops.
        reports: Extra report keys printed after ``'epoch'`` once per
            epoch. ``None`` (the default) means no extra keys.

    Resuming probes ``snapshot_epoch_000001``, ``snapshot_epoch_000002``, ...
    in ``folder`` and loads the last file of the unbroken run starting at
    epoch 1, never looking past ``nepoch``.
    """
    # Use a fresh list per call: a shared ``[]`` default would be aliased by
    # every instance through ``self.reports``.
    if reports is None:
        reports = []
    self.reports = reports
    self.nepoch = nepoch
    self.folder = folder
    self.chain = chain
    self.gpu = gpu
    if self.gpu >= 0:
        chainer.cuda.get_device(gpu).use()
        chain.to_gpu(gpu)
    # Separate link object used only for evaluation. Presumably it shares
    # parameter arrays with ``chain`` (chainer's Link.copy is shallow), so the
    # evaluator sees the weights being trained -- confirm for this chain type.
    self.eval_chain = eval_chain = chain.copy()
    # ``test`` flag toggles train/eval behaviour inside the chain (e.g.
    # dropout/batchnorm) -- assumed from usage; verify in the chain class.
    self.chain.test = False
    self.eval_chain.test = True
    self.testset = test
    if not os.path.exists(folder):
        os.makedirs(folder)

    train_iter = chainer.iterators.SerialIterator(train, batchsize, shuffle=True)
    test_iter = chainer.iterators.SerialIterator(test, batchsize,
                                                 repeat=False, shuffle=False)
    updater = training.StandardUpdater(train_iter, chain.optimizer, device=gpu)
    trainer = training.Trainer(updater, (nepoch, 'epoch'), out=folder)

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.Evaluator(test_iter, eval_chain, device=gpu), trigger=(1, 'epoch'))
    # Per-epoch snapshots: the model alone, and the full trainer state that
    # the resume logic below reloads.
    trainer.extend(extensions.snapshot_object(
        chain, 'chain_snapshot_epoch_{.updater.epoch:06}'), trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch_{.updater.epoch:06}'), trigger=(1, 'epoch'))
    # LogReport runs every iteration to collect observations; its own
    # trigger decides when results are aggregated and written (per epoch).
    trainer.extend(extensions.LogReport(trigger=(1, 'epoch')), trigger=(1, 'iteration'))
    trainer.extend(extensions.PrintReport(
        ['epoch'] + reports), trigger=(1, 'epoch'))
    self.trainer = trainer

    if resume:
        def snapshot_path(epoch):
            # Must match the filename pattern of the trainer snapshot above.
            return os.path.join(folder, 'snapshot_epoch_{:06}'.format(epoch))

        # Count the unbroken run of snapshots 1..k on disk (k <= nepoch).
        last_epoch = 0
        while last_epoch < nepoch and os.path.isfile(snapshot_path(last_epoch + 1)):
            last_epoch += 1
        # Epochs are 1-based; 0 means no snapshot was found.
        if last_epoch >= 1:
            S.load_npz(snapshot_path(last_epoch), trainer)
评论列表
文章目录