trainer.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:CausalGAN 作者: mkocaoglu 项目源码 文件源码
def prepare_model_dir(self):
    """Set up the run's output directories and record its configuration.

    Reads ``self.config`` (``load_path``, ``model_dir``, ``is_train``) and
    ``self.data_type``.

    Sets on ``self``:
        model_dir: ``config.load_path`` when it is truthy, otherwise a new
            ``<MMDD_HHMMSS>_<data_type>`` directory under
            ``config.model_dir``.
        save_model_dir: ``<model_dir>/checkpoints``.
        save_model_name: ``<save_model_dir>/Model``.

    Side effects:
        * Creates the directories above if they do not exist.
        * Writes ``config.__dict__`` to ``<model_dir>/params.json``
          (overwriting any existing file).
        * When training a fresh run (``config.is_train`` and no
          ``config.load_path``), sets ``config.log_code_dir`` to
          ``<model_dir>/code`` and copies every ``*.py`` file that sits
          next to the launched script (``sys.argv[0]``) into it.
    """
    config = self.config

    if config.load_path:
        self.model_dir = config.load_path
    else:
        run_name = datetime.now().strftime("%m%d_%H%M%S") + '_' + self.data_type
        self.model_dir = os.path.join(config.model_dir, run_name)

    # makedirs (not mkdir) so a missing parent directory does not abort setup.
    if not os.path.exists(self.model_dir):
        os.makedirs(self.model_dir)
    print('Model directory is ', self.model_dir)

    self.save_model_dir = os.path.join(self.model_dir, 'checkpoints')
    if not os.path.exists(self.save_model_dir):
        os.makedirs(self.save_model_dir)
    self.save_model_name = os.path.join(self.save_model_dir, 'Model')

    param_path = os.path.join(self.model_dir, "params.json")
    print("[*] MODEL dir: %s" % self.model_dir)
    print("[*] PARAM path: %s" % param_path)
    # Dumped before log_code_dir is attached below, so params.json does not
    # contain that key.
    with open(param_path, 'w') as fp:
        json.dump(config.__dict__, fp, indent=4, sort_keys=True)

    if config.is_train and not config.load_path:
        config.log_code_dir = os.path.join(self.model_dir, 'code')
        if not os.path.exists(config.log_code_dir):
            os.makedirs(config.log_code_dir)

        # Copy python code in directory into model_dir/code for future reference.
        code_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
        for name in os.listdir(code_dir):
            src = os.path.join(code_dir, name)
            # Copy from the full source path: a bare filename would be
            # resolved against the current working directory, which need
            # not be the script's directory.
            if name.endswith('.py') and os.path.isfile(src):
                shutil.copy2(src, config.log_code_dir)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号