main.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pytorch-tutorial 作者: yunjey 项目源码 文件源码
def main(config):
    cudnn.benchmark = True

    data_loader = get_loader(image_path=config.image_path,
                             image_size=config.image_size,
                             batch_size=config.batch_size,
                             num_workers=config.num_workers)

    solver = Solver(config, data_loader)

    # Create directories if not exist
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    # Train and sample the images
    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号