trainer.py source code

python

Project: minos  Author: guybedo
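import logging
import traceback
from time import time

# Method of the minos trainer class (class definition omitted).
# setup_tf_session and load_keras_model are helpers defined elsewhere in minos.
def train(self, blueprint, device,
          save_best_model=False, model_filename=None):
    try:
        # Build the model described by the blueprint for the target device.
        model = self.model_builder.build(
            blueprint,
            device)
        setup_tf_session(device)
        # Number of epochs and stopping callbacks derived from the blueprint.
        nb_epoch, callbacks = self._get_stopping_parameters(blueprint)
        if save_best_model:
            # Save the best model, by the blueprint's training metric, to model_filename.
            callbacks.append(self._get_model_save_callback(
                model_filename,
                blueprint.training.metric.metric))
        start = time()
        # Keras 1.x signature: fit_generator(generator, samples_per_epoch, nb_epoch, ...)
        history = model.fit_generator(
            self.batch_iterator,
            self.batch_iterator.samples_per_epoch,
            nb_epoch,
            callbacks=callbacks,
            validation_data=self.test_batch_iterator,
            nb_val_samples=self.test_batch_iterator.sample_count)
        if save_best_model:
            # Discard the final-epoch model and reload the saved best one.
            del model
            model = load_keras_model(model_filename)
        return model, history, (time() - start)
    except Exception as ex:
        logging.debug(ex)
        logging.debug(traceback.format_exc())
    # Only reached after a failure: release the Keras/TensorFlow session.
    try:
        from keras import backend
        backend.clear_session()
    except Exception as clear_ex:
        # `ex` is unbound here in Python 3, so log this handler's own exception.
        logging.debug(clear_ex)
        logging.debug(traceback.format_exc())
    return None, None, 0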
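The control flow of `train` can be sketched without Keras. The function below is a minimal, hypothetical sketch: `train_with_fallback` and its callable parameters (`build_model`, `fit`, `load_best`, `clear_session`) are stand-ins invented for illustration, not part of minos. It shows the return contract: `(model, history, elapsed_seconds)` on success and `(None, None, 0)` after any failure, with session cleanup attempted only on the failure path.

```python
import logging
import time
import traceback


def train_with_fallback(build_model, fit, load_best=None, clear_session=None):
    """Hypothetical stand-in for Trainer.train, with Keras calls replaced by callables.

    build_model   ~ model_builder.build
    fit           ~ model.fit_generator (returns a history object)
    load_best     ~ load_keras_model(model_filename), used when saving the best model
    clear_session ~ keras.backend.clear_session
    """
    try:
        model = build_model()
        start = time.time()
        history = fit(model)
        if load_best is not None:
            # Replace the final-epoch model with the best saved one.
            model = load_best()
        return model, history, time.time() - start
    except Exception as ex:
        logging.debug(ex)
        logging.debug(traceback.format_exc())
    # Only reached after a failure: try to release backend resources.
    if clear_session is not None:
        try:
            clear_session()
        except Exception as clear_ex:
            logging.debug(clear_ex)
    # Sentinel result so the caller can skip this candidate.
    return None, None, 0


model, history, elapsed = train_with_fallback(lambda: "model", lambda m: {"loss": [0.4]})
print(model, history)  # prints model {'loss': [0.4]}
```

Returning a sentinel tuple instead of re-raising lets a caller, for example a loop that evaluates many blueprints, treat a failed training run as a bad candidate and continue.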
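The `fit_generator` call uses the Keras 1.x arguments `samples_per_epoch`, `nb_epoch` and `nb_val_samples`, which count samples. Keras 2 renamed them to `steps_per_epoch`, `epochs` and `validation_steps`, and the step arguments count batches instead. The helper below is a hypothetical sketch of that conversion; it assumes both iterators yield batches of the same `batch_size`, which the snippet does not show, and rounds up so a partial final batch is still drawn.

```python
import math


def keras2_generator_args(samples_per_epoch, nb_epoch, nb_val_samples, batch_size):
    """Hypothetical helper: map Keras 1 fit_generator sample counts to Keras 2 step counts.

    batch_size is an assumption: the number of samples per generator yield,
    taken to be the same for the training and validation iterators.
    """
    return {
        # Keras 2 counts batches (steps), not samples.
        "steps_per_epoch": math.ceil(samples_per_epoch / batch_size),
        "epochs": nb_epoch,
        "validation_steps": math.ceil(nb_val_samples / batch_size),
    }


print(keras2_generator_args(1000, 10, 250, 32))
# prints {'steps_per_epoch': 32, 'epochs': 10, 'validation_steps': 8}
```

In TensorFlow 2.x, `Model.fit` accepts generators directly and `fit_generator` is deprecated, so these keyword arguments would be passed to `fit`.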
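A handler that runs after an earlier `except ... as ex` clause has finished should not rely on `ex`: in Python 3 the name bound by `as` is deleted when that clause exits. The small demonstration below (the function name is illustrative) shows that reading the name afterwards raises `UnboundLocalError`, a subclass of `NameError`, which is why the session-cleanup handler in `train` needs to log its own exception object.

```python
def name_after_except():
    """Show that Python 3 unbinds an `except ... as` name when the clause ends."""
    try:
        raise ValueError("first failure")
    except ValueError as ex:
        pass  # `ex` is deleted automatically when this clause exits
    try:
        return repr(ex)
    except NameError as err:  # UnboundLocalError subclasses NameError
        return type(err).__name__


print(name_after_except())  # prints UnboundLocalError
```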