utils.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目：AGE　作者：DmitryUlyanov　项目源码　文件源码
def setup(opt):
    '''
    Setups cudnn, seeds and parses updates string.
    '''
    opt.cuda = not opt.cpu

    torch.set_num_threads(4)

    if opt.nc is None:
        opt.nc = 1 if opt.dataset == 'mnist' else 3

    try:
        os.makedirs(opt.save_dir)
    except OSError:
        print('Directory was not created.')

    if opt.manual_seed is None:
        opt.manual_seed = random.randint(1, 10000)

    print("Random Seed: ", opt.manual_seed)
    random.seed(opt.manual_seed)
    torch.manual_seed(opt.manual_seed)
    torch.cuda.manual_seed_all(opt.manual_seed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device,"
              "so you should probably run with --cuda")

    updates = {'e': {}, 'g': {}}
    updates['e']['num_updates'] = int(opt.e_updates.split(';')[0])
    updates['e'].update({x.split(':')[0]: float(x.split(':')[1])
                         for x in opt.e_updates.split(';')[1].split(',')})

    updates['g']['num_updates'] = int(opt.g_updates.split(';')[0])
    updates['g'].update({x.split(':')[0]: float(x.split(':')[1])
                         for x in opt.g_updates.split(';')[1].split(',')})

    return updates
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号