def train(self, logdir: Path, train_ids: List[str], valid_ids: List[str],
          validation: str, no_mp: bool=False, valid_only: bool=False,
          model_path: Path=None):
    """Run the epoch loop: train, validate, and save a snapshot per epoch.

    Args:
        logdir: Directory given to the tensorboard logger and, when
            ``model_path`` is not set, to ``restore_last_snapshot``.
        train_ids: Ids of the training images; loaded in sorted order.
        valid_ids: Ids of the validation images. Only loaded (in sorted
            order, on first use) when ``validation`` is not ``'square'``.
        validation: ``'square'`` validates on the leading
            ``hps.validation_square`` slice along the last two axes of
            every training image's data and mask; any other value
            validates on the images named by ``valid_ids``.
        no_mp: Forwarded unchanged to ``train_on_images``.
        valid_only: Skip training entirely: run a single validation pass
            and return without saving a snapshot.
        model_path: Snapshot to restore. Its file name must end in
            ``-<epoch>``; training resumes at the following epoch. When
            omitted, ``restore_last_snapshot(logdir)`` supplies the
            starting epoch.

    ``self.tb_logger`` and ``self.logdir`` are set for the duration of the
    call and reset to ``None`` afterwards, even if training raises.
    """
    self.tb_logger = tensorboard_logger.Logger(str(logdir))
    self.logdir = logdir
    try:
        train_images = [self.load_image(im_id) for im_id in sorted(train_ids)]
        valid_images = None
        if model_path:
            self.restore_snapshot(model_path)
            # The snapshot name ends in "-<epoch>"; continue with the next one.
            start_epoch = int(model_path.name.rsplit('-', 1)[1]) + 1
        else:
            start_epoch = self.restore_last_snapshot(logdir)
        square_validation = validation == 'square'
        lr = self.hps.lr
        self.optimizer = self._init_optimizer(lr)
        subsample = 1 if valid_only else 2  # make validation more often
        for n_epoch in range(start_epoch, self.hps.n_epochs):
            is_first_epoch = n_epoch == start_epoch
            # Work out whether the learning rate changes at this epoch. The
            # optimizer is rebuilt only when it does, and only once per epoch.
            new_lr = None
            if self.hps.lr_decay:
                # Exponential decay, refreshed on even epochs and on resume.
                if n_epoch % 2 == 0 or is_first_epoch:
                    new_lr = self.hps.lr * self.hps.lr_decay ** n_epoch
            else:
                # Step schedule: /5 from epoch lim_1, /25 from epoch lim_2.
                # The `is_first_epoch` clauses re-apply the right step when
                # resuming from a snapshot taken after a threshold. The later
                # threshold is tested first so that it wins on such a resume.
                lim_1, lim_2 = 25, 50
                if n_epoch == lim_2 or (is_first_epoch and n_epoch > lim_2):
                    new_lr = self.hps.lr / 25
                elif n_epoch == lim_1 or (is_first_epoch and n_epoch > lim_1):
                    new_lr = self.hps.lr / 5
            if new_lr is not None:
                lr = new_lr
                self.optimizer = self._init_optimizer(lr)
            logger.info('Starting epoch {}, step {:,}, lr {:.8f}'.format(
                n_epoch + 1, self.net.global_step[0], lr))
            for _ in range(subsample):
                if not valid_only:
                    self.train_on_images(
                        train_images,
                        subsample=subsample,
                        square_validation=square_validation,
                        no_mp=no_mp)
                if valid_images is None:
                    # Built on first use and reused for all later passes.
                    if square_validation:
                        s = self.hps.validation_square
                        valid_images = [
                            Image(None, im.data[:, :s, :s], im.mask[:, :s, :s])
                            for im in train_images]
                    else:
                        valid_images = [self.load_image(im_id)
                                        for im_id in sorted(valid_ids)]
                if valid_images:
                    self.validate_on_images(valid_images, subsample=1)
            if valid_only:
                # Validation-only runs stop after one pass and save nothing.
                break
            self.save_snapshot(n_epoch)
    finally:
        # Always drop the per-run logger state, including on failure.
        self.tb_logger = None
        self.logdir = None
# Web-page navigation residue, not part of the code:
# "评论列表" (comment list), "文章目录" (table of contents).