main.py source code

python

Project: DenseNet    Author: kevinzakka

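import torch

# project-local helpers; module names assumed from the repository layout
from trainer import Trainer
from utils import prepare_dirs, save_config
from data_loader import get_test_loader, get_train_valid_loader


def main(config):

    # ensure directories are set up
    prepare_dirs(config)

    # seed the CPU RNG unconditionally (model weights are typically
    # initialized on the CPU) and, when GPUs are used, the CUDA RNG as well
    torch.manual_seed(config.random_seed)
    kwargs = {}
    if config.num_gpu > 0:
        torch.cuda.manual_seed(config.random_seed)
        # extra DataLoader options: one worker process for loading and
        # pinned host memory for faster copies to the GPU
        kwargs = {'num_workers': 1, 'pin_memory': True}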

    # instantiate data loaders
    if config.is_train:
        data_loader = get_train_valid_loader(
            config.data_dir, config.dataset, config.batch_size,
            config.augment, config.random_seed, config.valid_size,
            config.shuffle, config.show_sample, **kwargs)
    else:
        data_loader = get_test_loader(
            config.data_dir, config.dataset, config.batch_size,
            config.shuffle, **kwargs)

    # instantiate trainer
    trainer = Trainer(config, data_loader)

    # either train
    if config.is_train:
        save_config(config)
        trainer.train()

    # or load a pretrained model and test
    else:
        trainer.test()
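main() reads every setting as an attribute of `config`, so any object exposing those fields can drive it. The following is a minimal sketch using the standard-library `argparse` module, assuming command-line flags named after the attributes main() accesses. The helper names `build_config` and `str2bool` and all defaults are hypothetical; they are not taken from the project's own config module, and `prepare_dirs`, `save_config` and `Trainer` may need further fields not listed here.

```python
import argparse


def str2bool(value):
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so boolean flags are parsed explicitly
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_config(argv=None):
    """Hypothetical config parser exposing the attributes main() reads."""
    parser = argparse.ArgumentParser(description="DenseNet train/test driver (sketch)")
    # hardware and reproducibility
    parser.add_argument("--num_gpu", type=int, default=0)
    parser.add_argument("--random_seed", type=int, default=0)
    # mode: train with a validation split, or evaluate on the test set
    parser.add_argument("--is_train", type=str2bool, default=True)
    # data loading
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--augment", type=str2bool, default=True)
    parser.add_argument("--valid_size", type=float, default=0.1)
    parser.add_argument("--shuffle", type=str2bool, default=True)
    parser.add_argument("--show_sample", type=str2bool, default=False)
    # argv=None makes argparse fall back to sys.argv[1:]
    return parser.parse_args(argv)
```

With such a parser, a script entry point could call `main(build_config())`; passing an explicit list (for example `build_config(["--is_train", "false"])`) is handy for switching to test mode from a notebook.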
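In training mode, main() passes `config.valid_size` and `config.random_seed` to `get_train_valid_loader`, which suggests the helper holds out a reproducible validation fraction of the training set. Its body is not shown here. The sketch below illustrates one common way such a split is done, assuming the indices are shuffled with a fixed seed and the first `valid_size` fraction is held out. The function `split_indices` is hypothetical and uses only the standard library.

```python
import math
import random


def split_indices(num_samples, valid_size, random_seed, shuffle=True):
    """Return (train_idx, valid_idx), holding out a valid_size fraction.

    Hypothetical helper illustrating a seeded hold-out split; not the project's code.
    """
    if not 0.0 <= valid_size <= 1.0:
        raise ValueError("valid_size should be in the range [0, 1]")
    indices = list(range(num_samples))
    if shuffle:
        # a private RNG seeded with random_seed makes the split reproducible
        # without touching the global random state
        random.Random(random_seed).shuffle(indices)
    split = math.floor(valid_size * num_samples)
    # the first `split` (shuffled) indices form the validation set, the rest train
    return indices[split:], indices[:split]
```

In a PyTorch loader, the two index lists would typically be wrapped in `torch.utils.data.SubsetRandomSampler` objects, each passed as the `sampler` argument of its own `DataLoader` over the same dataset, so the training and validation batches never overlap.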