training.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:EUSIPCO2017 作者: Veleslavia 项目源码 文件源码
def train(self):
        """Build, compile and fit the model, then pickle its training history.

        The model is obtained from ``self.model_module.build_model`` and
        trained with ``fit_generator`` on batches produced by
        ``self._batch_generator`` for the train and validation splits.
        Three callbacks are attached:

        * ``ModelCheckpoint`` -- saves weights only when ``val_loss`` improves;
        * ``EarlyStopping`` -- stops after ``EARLY_STOPPING_EPOCH`` epochs
          without ``val_loss`` improvement;
        * ``LearningRateScheduler`` -- halves ``self.init_lr`` every
          ``SGD_LR_REDUCE`` epochs.

        After fitting, ``history.history`` is pickled under
        ``MODEL_HISTORY_BASEPATH``. Nothing is returned.
        """
        model = self.model_module.build_model(IRMAS_N_CLASSES)

        # Only the directory prefix and the "-{key}.hdf5" suffix are formatted
        # here. The middle part keeps its literal ``{epoch}``, ``{val_loss}``
        # and ``{val_fbeta_score}`` placeholders so that ModelCheckpoint can
        # fill them in at save time.
        weights_dir = "{weights_basepath}/{model_path}/".format(
            weights_basepath=MODEL_WEIGHT_BASEPATH,
            model_path=self.model_module.BASE_NAME)
        weights_suffix = "-{key}.hdf5".format(key=self.model_module.MODEL_KEY)
        weights_path = (
            weights_dir +
            "epoch.{epoch:02d}-val_loss.{val_loss:.3f}-fbeta.{val_fbeta_score:.3f}" +
            weights_suffix)

        early_stopping = EarlyStopping(monitor='val_loss', patience=EARLY_STOPPING_EPOCH)
        save_clb = ModelCheckpoint(weights_path,
                                   monitor='val_loss',
                                   save_best_only=True)
        # Step decay: the divisor doubles each time another SGD_LR_REDUCE
        # epochs have elapsed.
        lrs = LearningRateScheduler(lambda epoch_n: self.init_lr / (2**(epoch_n//SGD_LR_REDUCE)))

        model.summary()
        model.compile(optimizer=self.optimizer,
                      loss='categorical_crossentropy',
                      metrics=['accuracy', fbeta_score])

        history = model.fit_generator(self._batch_generator(self.X_train, self.y_train),
                                      samples_per_epoch=self.model_module.SAMPLES_PER_EPOCH,
                                      nb_epoch=MAX_EPOCH_NUM,
                                      verbose=2,
                                      callbacks=[save_clb, early_stopping, lrs],
                                      validation_data=self._batch_generator(self.X_val, self.y_val),
                                      nb_val_samples=self.model_module.SAMPLES_PER_VALIDATION,
                                      class_weight=None,
                                      nb_worker=1)

        history_path = '{history_basepath}/{model_path}/history_{model_key}.pkl'.format(
            history_basepath=MODEL_HISTORY_BASEPATH,
            model_path=self.model_module.BASE_NAME,
            model_key=self.model_module.MODEL_KEY)
        # Pickle output is bytes, so the file must be opened in binary mode
        # (text mode fails on Python 3). The ``with`` block guarantees the
        # handle is flushed and closed even if dumping raises.
        with open(history_path, 'wb') as history_file:
            pickle.dump(history.history, history_file)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号