GA.py source code

python

Project: NetworkCompress  Author: luzai
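import keras      # standalone Keras; provides keras.models.save_model
import GAClient   # NetworkCompress project module (Client, run)

# `parallel` is assumed to be a module-level boolean flag defined elsewhere in GA.py.

def train_process(self):
    # Train every model in the population and record its score on the model object.
    client = GAClient.Client()
    for model in self.population.values():
        # if getattr(model, 'parent', None) is not None:
        # A parent implies mutation and changed weights, so the weights must be saved.
        keras.models.save_model(model.model, model.config.model_path)
        model.graph.save_params(model.config.output_path + '/graph.json')

        kwargs = dict(
            name=model.config.name,
            epochs=model.config.epochs,
            verbose=model.config.verbose,
            limit_data=model.config.limit_data,
            dataset_type=model.config.dataset_type
        )
        if parallel:
            # Dispatch the job through the client; scores are read back after client.wait().
            client.run_self(kwargs)
        else:
            # Serial path: train and score this model in the current process.
            name, score = GAClient.run(**kwargs)
            self.population[name].score = score

    if parallel:
        # Block until all dispatched jobs finish, then copy each score back by model name.
        client.wait()
        for name, score in client.scores.items():
            self.population[name].score = score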
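The function follows a dispatch-then-collect pattern. In parallel mode every job is handed to the client first, `client.wait()` blocks until all of them finish, and the scores are then written back to the population by model name. In serial mode each model is trained and scored inline. The following is a minimal sketch of that pattern in isolation, under stated assumptions: `Individual` and `train_and_score` are hypothetical stand-ins for a population member and for `GAClient.run`, and a `concurrent.futures` thread pool substitutes for the project's own client, whose internals are not shown here.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, Tuple


class Individual:
    """Hypothetical stand-in for one population member (model plus config)."""

    def __init__(self, name: str, epochs: int) -> None:
        self.name = name
        self.epochs = epochs
        self.score: Optional[float] = None


def train_and_score(name: str, epochs: int) -> Tuple[str, float]:
    # Hypothetical stand-in for GAClient.run: returns (name, score).
    # A real version would load the saved model from disk and train it.
    return name, float(len(name) * epochs)


def train_population(population: Dict[str, Individual], parallel: bool) -> None:
    jobs = [dict(name=ind.name, epochs=ind.epochs) for ind in population.values()]
    if parallel:
        # Submit every job first, then wait for all results,
        # mirroring client.run_self(...) followed by client.wait().
        with ThreadPoolExecutor() as pool:
            futures = [pool.submit(train_and_score, **kw) for kw in jobs]
            results = [f.result() for f in futures]
    else:
        # Serial path: train one model at a time in this thread.
        results = [train_and_score(**kw) for kw in jobs]
    # Write scores back by name, as the original does via self.population[name].
    for name, score in results:
        population[name].score = score


if __name__ == "__main__":
    pop = {n: Individual(n, e) for n, e in [("net_a", 2), ("net_bb", 3)]}
    train_population(pop, parallel=True)
    print({n: ind.score for n, ind in pop.items()})  # {'net_a': 10.0, 'net_bb': 18.0}
```

Keying the write-back on the model name means the result loop does not depend on the order in which jobs finish, which matters when scores come back in a dictionary such as `client.scores`.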