def main(config):
    """Build the data loader and solver, then run the mode chosen in ``config``.

    Args:
        config: Settings object. This function reads ``image_path``,
            ``image_size``, ``batch_size``, ``num_workers``, ``model_path``,
            ``sample_path`` and ``mode`` from it, and also passes the whole
            object on to ``Solver``.

    Returns:
        None. When ``config.mode`` is neither ``'train'`` nor ``'sample'``
        nothing is run after the setup steps.
    """
    # Set the cuDNN benchmark flag for the whole process
    # (presumably torch.backends.cudnn -- confirm against the file's imports).
    cudnn.benchmark = True

    data_loader = get_loader(
        image_path=config.image_path,
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )
    solver = Solver(config, data_loader)

    # Create the output directories if they do not exist. exist_ok=True avoids
    # the race in a separate "exists?" check followed by makedirs, where
    # another process could create the directory in between.
    os.makedirs(config.model_path, exist_ok=True)
    os.makedirs(config.sample_path, exist_ok=True)

    # Train or sample, depending on the requested mode.
    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()
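# Web-page residue (not part of the program): "Comment list"
# Web-page residue (not part of the program): "Table of contents"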