def setup(opt):
'''
Setups cudnn, seeds and parses updates string.
'''
opt.cuda = not opt.cpu
torch.set_num_threads(4)
if opt.nc is None:
opt.nc = 1 if opt.dataset == 'mnist' else 3
try:
os.makedirs(opt.save_dir)
except OSError:
print('Directory was not created.')
if opt.manual_seed is None:
opt.manual_seed = random.randint(1, 10000)
print("Random Seed: ", opt.manual_seed)
random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed_all(opt.manual_seed)
cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
print("WARNING: You have a CUDA device,"
"so you should probably run with --cuda")
updates = {'e': {}, 'g': {}}
updates['e']['num_updates'] = int(opt.e_updates.split(';')[0])
updates['e'].update({x.split(':')[0]: float(x.split(':')[1])
for x in opt.e_updates.split(';')[1].split(',')})
updates['g']['num_updates'] = int(opt.g_updates.split(';')[0])
updates['g'].update({x.split(':')[0]: float(x.split(':')[1])
for x in opt.g_updates.split(';')[1].split(',')})
return updates
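# Scraped web-page residue, not part of the program: "Comment list"
# Scraped web-page residue, not part of the program: "Article table of contents"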