train.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:kaggle-dstl 作者: lopuhin 项目源码 文件源码
def train(self, logdir: Path, train_ids: List[str], valid_ids: List[str],
              validation: str, no_mp: bool = False, valid_only: bool = False,
              model_path: Path = None):
        """Train the network epoch by epoch, validating and saving snapshots.

        Training resumes from ``model_path`` when it is given, otherwise from
        whatever ``self.restore_last_snapshot(logdir)`` returns. Each epoch
        calls ``train_on_images`` twice with ``subsample=2``, runs validation
        after each call and then saves a snapshot for the epoch.

        Args:
            logdir: directory handed to the tensorboard logger; also stored
                on ``self.logdir`` for the duration of the call.
            train_ids: ids of the training images (loaded in sorted order).
            valid_ids: ids of the validation images; only used when
                ``validation != 'square'``.
            validation: ``'square'`` validates on the leading
                ``hps.validation_square`` window of the last two axes of
                every training image (presumably a top-left spatial crop);
                any other value validates on ``valid_ids``.
            no_mp: forwarded unchanged to ``train_on_images`` (presumably
                disables multiprocessing -- confirm there).
            valid_only: skip training and snapshotting and run validation
                exactly once, even if training has already finished.
            model_path: explicit snapshot to restore; the start epoch is the
                integer after the last ``'-'`` in its file name, plus one.
        """
        self.tb_logger = tensorboard_logger.Logger(str(logdir))
        self.logdir = logdir
        try:
            train_images = [self.load_image(im_id)
                            for im_id in sorted(train_ids)]
            valid_images = None
            if model_path:
                self.restore_snapshot(model_path)
                start_epoch = int(model_path.name.rsplit('-', 1)[1]) + 1
            else:
                start_epoch = self.restore_last_snapshot(logdir)
            square_validation = validation == 'square'
            lr = self.hps.lr
            self.optimizer = self._init_optimizer(lr)
            end_epoch = self.hps.n_epochs
            if valid_only:
                # A finished run has start_epoch >= n_epochs, which would give
                # an empty loop and silently skip validation; force one pass.
                end_epoch = max(end_epoch, start_epoch + 1)
            lim_1, lim_2 = 25, 50  # epochs where the step schedule drops lr
            for n_epoch in range(start_epoch, end_epoch):
                if self.hps.lr_decay:
                    # Exponential decay, but the optimizer is only rebuilt on
                    # even epochs and on the first epoch of this run.
                    if n_epoch % 2 == 0 or n_epoch == start_epoch:
                        lr = self.hps.lr * self.hps.lr_decay ** n_epoch
                        self.optimizer = self._init_optimizer(lr)
                elif (n_epoch in (lim_1, lim_2) or
                      (n_epoch == start_epoch and n_epoch > lim_1)):
                    # Step schedule: lr / 5 from lim_1, lr / 25 from lim_2.
                    # A resumed run jumps straight to its stage, rebuilding
                    # the optimizer once.
                    lr = self.hps.lr / (25 if n_epoch >= lim_2 else 5)
                    self.optimizer = self._init_optimizer(lr)
                logger.info('Starting epoch {}, step {:,}, lr {:.8f}'.format(
                    n_epoch + 1, self.net.global_step[0], lr))
                subsample = 1 if valid_only else 2  # make validation more often
                for _ in range(subsample):
                    if not valid_only:
                        self.train_on_images(
                            train_images,
                            subsample=subsample,
                            square_validation=square_validation,
                            no_mp=no_mp)
                    if valid_images is None:
                        # Built lazily, once, after the first training pass.
                        if square_validation:
                            s = self.hps.validation_square
                            valid_images = [
                                Image(None, im.data[:, :s, :s],
                                      im.mask[:, :s, :s])
                                for im in train_images]
                        else:
                            valid_images = [self.load_image(im_id)
                                            for im_id in sorted(valid_ids)]
                    if valid_images:
                        self.validate_on_images(valid_images, subsample=1)
                if valid_only:
                    break
                self.save_snapshot(n_epoch)
        finally:
            # Detach per-run state even if training raised or was interrupted.
            self.tb_logger = None
            self.logdir = None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号