def main(config):
    """Build the data loaders and solver, then run the mode chosen by ``config``.

    ``get_loader(config)`` supplies four loaders. The first two are handed to
    ``Solver``. The last two (the test loaders) are passed to
    ``solver.train`` in ``'train'`` mode.

    Args:
        config: Run configuration. This function reads ``config.mode``,
            ``config.model_path`` and ``config.sample_path``. The whole object
            is also forwarded to ``get_loader`` and ``Solver``.

    Raises:
        ValueError: If ``config.mode`` is neither ``'train'`` nor ``'sample'``.
    """
    # Fail fast on a bad mode before spending time building loaders/solver;
    # previously an unknown mode silently did nothing after all that setup.
    if config.mode not in ('train', 'sample'):
        raise ValueError(
            "unsupported mode %r; expected 'train' or 'sample'" % (config.mode,))

    svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config)
    solver = Solver(config, svhn_loader, mnist_loader)

    # Let cuDNN benchmark convolution algorithms and pick the fastest ones
    # (most useful when input sizes stay fixed across iterations).
    cudnn.benchmark = True

    # Create output directories if missing. exist_ok=True avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(config.model_path, exist_ok=True)
    os.makedirs(config.sample_path, exist_ok=True)

    if config.mode == 'train':
        solver.train(svhn_test_loader, mnist_test_loader)
    elif config.mode == 'sample':
        solver.sample()
评论列表
文章目录