def save_images_grid(imgs, path, grid_w=4, grid_h=4, post_processing=postprocessing_tanh, transposed=False):
    """Tile a batch of images into one grid image and write it to ``path``.

    The batch is passed through ``copy_to_cpu`` and, unless
    ``post_processing`` is ``None``, through ``post_processing``. The result
    must be a rank-4 array whose shape unpacks as ``(b, ch, w, h)`` with
    ``b == grid_w * grid_h``.

    Layout of the written array:

    * ``transposed=False``: shape ``(grid_w * w, grid_h * h, ch)``;
      consecutive batch entries advance along the second axis, with
      ``grid_h`` tiles per line.
    * ``transposed=True``: shape ``(grid_h * w, grid_w * h, ch)``;
      consecutive batch entries advance along the first axis.

    For ``ch == 1`` the trailing channel axis is dropped so a 2-D array is
    written.

    Args:
        imgs: Batch of images; anything ``copy_to_cpu`` accepts.
            NOTE(review): presumably ends up as a NumPy array, since
            ``reshape``/``transpose`` are called on it -- confirm.
        path: Destination file name handed to ``cv2.imwrite``.
        grid_w: Number of tiles along the first axis of the output
            (second axis when ``transposed`` is true).
        grid_h: Number of tiles along the second axis of the output
            (first axis when ``transposed`` is true).
        post_processing: Optional callable applied to the batch before
            tiling; ``None`` skips it.
        transposed: Selects the tile ordering described above.

    Raises:
        AssertionError: If the batch size differs from ``grid_w * grid_h``.
    """
    imgs = copy_to_cpu(imgs)
    if post_processing is not None:
        imgs = post_processing(imgs)

    b, ch, w, h = imgs.shape
    assert b == grid_w * grid_h, (
        "batch size %d does not match grid %dx%d" % (b, grid_w, grid_h)
    )

    # Split the batch axis into grid coordinates and move channels last:
    # (grid_w, grid_h, ch, w, h) -> (grid_w, grid_h, w, h, ch).
    tiles = imgs.reshape((grid_w, grid_h, ch, w, h)).transpose(0, 1, 3, 4, 2)

    if transposed:
        # (grid_h, w, grid_w, h, ch) -> one canvas.
        canvas = tiles.transpose(1, 2, 0, 3, 4).reshape((grid_h * w, grid_w * h, ch))
    else:
        # (grid_w, w, grid_h, h, ch) -> one canvas.
        canvas = tiles.transpose(0, 2, 1, 3, 4).reshape((grid_w * w, grid_h * h, ch))

    if ch == 1:
        # Drop the channel axis from the canvas that was actually built.
        # Reshaping to a hard-coded (grid_w*w, grid_h*h) would silently
        # scramble the transposed layout whenever the two extents differ.
        canvas = canvas.reshape(canvas.shape[:2])

    cv2.imwrite(path, canvas)
评论列表
文章目录