def make_sprite(label_img, save_path):
    """Write a square sprite sheet of ``label_img`` to ``<save_path>/sprite.png``.

    The sprite must be a square grid of tiles, as described in
    https://www.tensorflow.org/get_started/embedding_viz. The batch is
    therefore padded with blank (all-zero) tiles until it holds
    ``nrow * nrow`` images, where ``nrow = ceil(sqrt(N))`` and ``N`` is
    ``label_img.size(0)``.

    Args:
        label_img: Tensor of images with the batch along dim 0. Presumably
            (N, C, H, W) with values in [0, 1], as
            ``torchvision.utils.save_image`` expects. TODO: confirm with callers.
        save_path: Directory in which ``sprite.png`` is written.

    Raises:
        ValueError: If ``label_img`` contains no images.
    """
    import math
    import os

    import torch
    import torchvision

    num_images = label_img.size(0)
    if num_images == 0:
        # With nrow == 0, make_grid would fail with an obscure division error.
        raise ValueError('label_img must contain at least one image')

    nrow = int(math.ceil(num_images ** 0.5))

    # Pad the batch so that the number of images equals nrow * nrow.
    # Zero tiles keep the output deterministic. new_zeros matches
    # label_img's dtype and device, so torch.cat also works for CUDA and
    # non-float32 inputs.
    filler_shape = (nrow ** 2 - num_images,) + tuple(label_img.size()[1:])
    label_img = torch.cat((label_img, label_img.new_zeros(filler_shape)), 0)

    sprite_file = os.path.join(save_path, 'sprite.png')

    # Workaround for https://github.com/pytorch/vision/issues/206: probe
    # make_grid with a single 32x32 image and padding=0. If it comes back
    # 33 pixels wide, the installed torchvision adds an extra leading
    # row/column. In that case build the grid here and crop that pixel off
    # before saving.
    probe = torchvision.utils.make_grid(torch.zeros(1, 3, 32, 32), padding=0)
    if probe.size(2) == 33:
        sprite = torchvision.utils.make_grid(label_img, nrow=nrow, padding=0)
        sprite = sprite[:, 1:, 1:]
        torchvision.utils.save_image(sprite, sprite_file)
    else:
        torchvision.utils.save_image(label_img, sprite_file, nrow=nrow, padding=0)
评论列表
文章目录