snapshot.py source code

python
Project: keras-contrib  Author: farizrahman4u
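def get_callbacks(self, model_prefix='Model'):
    """
    Creates a list of callbacks that can be used during training to create a
    snapshot ensemble of the model.

    Args:
        model_prefix: prefix for the filename of the weights.

    Returns: list of 3 callbacks [ModelCheckpoint, LearningRateScheduler,
             SnapshotModelCheckpoint] which can be provided to the 'fit' function
    """
    # Make sure the output directory for the weight files exists.
    if not os.path.exists('weights/'):
        os.makedirs('weights/')

    callback_list = [
        # Keeps only the weights with the best monitored validation accuracy.
        # Note: 'val_acc' is the name used by older Keras releases; newer ones
        # log this metric as 'val_accuracy'.
        ModelCheckpoint('weights/%s-Best.h5' % model_prefix, monitor='val_acc',
                        save_best_only=True, save_weights_only=True),
        # Sets the learning rate each epoch from the cosine annealing schedule.
        LearningRateScheduler(schedule=self._cosine_anneal_schedule),
        # Snapshot checkpointing, with weight files named after model_prefix.
        SnapshotModelCheckpoint(self.T, self.M, fn_prefix='weights/%s' % model_prefix),
    ]

    return callback_list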
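
The snippet passes `self._cosine_anneal_schedule` to `LearningRateScheduler`, but the body of that method is not shown here. Below is a minimal standalone sketch of what a cyclic cosine annealing schedule for snapshot ensembles could look like. It assumes that `T` is the total number of epochs and `M` the number of snapshots; the function name `cosine_anneal_lr` and its parameter names are hypothetical and do not come from the original file.

```python
import math


def cosine_anneal_lr(epoch, total_epochs, num_snapshots, initial_lr):
    """Cyclic cosine-annealed learning rate for a zero-based epoch index.

    Hypothetical helper: total_epochs and num_snapshots are assumed to
    correspond to T and M in the snippet above.
    """
    # Length of one annealing cycle; one snapshot would be taken per cycle.
    cycle_len = total_epochs // num_snapshots
    # Position inside the current cycle, in the range [0, cycle_len).
    pos = epoch % cycle_len
    # Starts at initial_lr when a cycle begins and decays toward 0 at its end.
    return initial_lr / 2.0 * (math.cos(math.pi * pos / cycle_len) + 1.0)


if __name__ == "__main__":
    # Print the schedule for 6 epochs split into 2 cycles.
    for epoch in range(6):
        print(epoch, cosine_anneal_lr(epoch, 6, 2, 0.1))
```

A function with this shape (epoch index in, learning rate out) is what `LearningRateScheduler` expects as its `schedule` argument. The learning rate restarts at `initial_lr` at the start of every cycle, which is what lets each snapshot settle into a different local minimum.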