fr3dnet_trainer.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:kaggle_dsb 作者: syagev 项目源码 文件源码
def train(self, X_train, X_val, lr_decay=0.985):
        """Run the training loop for ``P.N_EPOCHS`` epochs.

        Splits both datasets by the value at index 2 of each sample (1 vs 0),
        produces per-epoch data via ``make_epoch`` driven by a single-producer,
        non-multiprocess ``ParallelBatchIterator``, runs the training and
        validation batches each epoch, and then updates the learning rate.

        Args:
            X_train: Iterable of training samples. ``x[2] == 1`` selects the
                "true" samples and ``x[2] == 0`` the "false" ones; samples with
                any other value at index 2 are not used. (Presumably a binary
                label -- TODO confirm against the data loader.)
            X_val: Validation samples with the same layout as ``X_train``.
            lr_decay: Multiplicative per-epoch decay factor. After each epoch
                the learning rate is set to
                ``P.LEARNING_RATE * lr_decay ** self.epoch``. Defaults to the
                previously hard-coded 0.985.
        """
        # Materialize as lists rather than using filter(): on Python 3 filter()
        # returns a lazy one-shot iterator, so len() on it fails and any second
        # pass over it (e.g. by make_epoch on a later epoch) would see no data.
        train_true = [x for x in X_train if x[2] == 1]
        train_false = [x for x in X_train if x[2] == 0]

        val_true = [x for x in X_val if x[2] == 1]
        val_false = [x for x in X_val if x[2] == 0]

        make_epoch_helper = functools.partial(
            make_epoch,
            train_true=train_true,
            train_false=train_false,
            val_true=val_true,
            val_false=val_false,
        )

        logging.info("Starting training...")
        # One item per epoch index; single producer, no multiprocessing.
        epoch_iterator = ParallelBatchIterator(
            make_epoch_helper,
            range(P.N_EPOCHS),
            ordered=False,
            batch_size=1,
            multiprocess=False,
            n_producers=1,
        )

        for epoch_values in epoch_iterator:
            self.pre_epoch()
            # Each item from the iterator unpacks into (train data, val data).
            train_epoch_data, val_epoch_data = epoch_values

            train_epoch_data = util.chunks(train_epoch_data, P.BATCH_SIZE_TRAIN)
            val_epoch_data = util.chunks(val_epoch_data, P.BATCH_SIZE_VALIDATION)

            self.do_batches(self.train_fn, train_epoch_data, self.train_metrics)
            self.do_batches(self.val_fn, val_epoch_data, self.val_metrics)

            self.post_epoch()

            # NOTE(review): assumes post_epoch() advances self.epoch, so the
            # decay applies to the next epoch -- confirm in the base class.
            learning_rate = P.LEARNING_RATE * (lr_decay ** self.epoch)
            logging.info("Setting learning rate to {}".format(learning_rate))
            # NOTE(review): self.l_r looks like a shared variable exposing
            # set_value (e.g. Theano) -- confirm.
            self.l_r.set_value(learning_rate)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号