def stylize(args):
    """Run a trained ``TransformerNet`` on one image and save the stylized result.

    Args:
        args: Namespace-like object. The following attributes are read:
            content_image: forwarded to ``utils.load_image`` as the image to load.
            content_scale: forwarded to ``utils.load_image`` as ``scale``.
            cuda: truthy to run on the GPU, falsy to run on the CPU.
            model: passed to ``torch.load``; must yield a state dict that
                ``TransformerNet.load_state_dict`` accepts.
            output_image: forwarded to ``utils.save_image`` as the destination.

    Returns:
        None. The result is written via ``utils.save_image``.
    """
    # Resolve the target device once instead of branching on args.cuda
    # at every transfer.
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        # Multiply by 255 after ToTensor; presumably the network expects
        # inputs in the 0-255 range -- confirm against the training code.
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    # unsqueeze(0) adds the leading batch dimension the model is called with.
    content_image = content_transform(content_image).unsqueeze(0).to(device)

    style_model = TransformerNet()
    # Load onto the CPU first so a checkpoint saved from a GPU still loads on
    # a CPU-only machine; the model is moved to the target device afterwards.
    state_dict = torch.load(args.model, map_location="cpu")
    style_model.load_state_dict(state_dict)
    style_model.to(device)
    # Inference only: switch any train/eval-sensitive layers to eval behavior.
    style_model.eval()

    # torch.no_grad() replaces the removed ``Variable(..., volatile=True)``
    # idiom: no autograd graph is recorded, which keeps memory usage down.
    with torch.no_grad():
        output = style_model(content_image).cpu()

    # Drop the batch dimension before saving.
    utils.save_image(args.output_image, output[0])
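# Comment list  (web-page navigation label left over from extraction; not code)
# Table of contents  (web-page navigation label left over from extraction; not code)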