def train(self):
    """
    Train the model on the training set.

    Runs epochs ``self.start_epoch`` .. ``self.epochs - 1``. Each epoch does
    the following:

    1. Optionally anneal the learning rate (when ``self.is_decay`` is set).
    2. Train for one epoch.
    3. Evaluate on the validation set.
    4. Save a checkpoint.

    The checkpoint is saved after every epoch. It is additionally flagged
    ``is_best`` when the validation accuracy strictly improves on
    ``self.best_valid_acc``, so that a separate checkpoint can be kept for
    use on the test set.

    If ``self.resume`` is set, the most recent (not the best) checkpoint is
    loaded before training starts.
    """
    # load the most recent checkpoint
    if self.resume:
        self.load_checkpoint(best=False)

    for epoch in trange(self.start_epoch, self.epochs):
        # Switch to train mode (for dropout) at the start of EVERY epoch,
        # not just once before the loop. The validation pass below runs
        # between epochs. If validate() puts the model into eval mode
        # (NOTE(review): validate() is not visible here - confirm), a
        # single up-front call would leave dropout disabled for all later
        # epochs.
        self.model.train()

        # decay learning rate
        if self.is_decay:
            self.anneal_learning_rate(epoch)

        # train for 1 epoch
        self.train_one_epoch(epoch)

        # evaluate on validation set
        valid_acc = self.validate(epoch)

        # Strict '>' means a tie does not mark this epoch as the new best.
        is_best = valid_acc > self.best_valid_acc
        self.best_valid_acc = max(valid_acc, self.best_valid_acc)

        # 'epoch' stores epoch + 1, i.e. the index of the next epoch to run.
        # Presumably load_checkpoint() reads it back into start_epoch when
        # resuming - confirm.
        self.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'best_valid_acc': self.best_valid_acc,
            },
            is_best,
        )
评论列表
文章目录