def combine_images(generated_images):
    """Tile a batch of images into a single grid image.

    The grid has ``int(sqrt(N))`` columns (at least 1) and as many rows as
    needed. Images are placed left-to-right, top-to-bottom, and unused grid
    cells stay zero.

    Args:
        generated_images: array of shape ``(N, height, width, channels)``.
            Axis 1 of each image spans rows of the output grid and axis 2
            spans columns.

    Returns:
        Array of shape ``(height * rows, width * cols, out_channels)`` with
        the same dtype as the input. ``out_channels`` is 3 for
        single-channel input (broadcast to RGB, as before) and ``channels``
        otherwise. An empty batch yields an array with zero rows.
    """
    total, height, width, ch = generated_images.shape
    # max(1, ...) avoids a division by zero when the batch is empty.
    cols = max(1, int(math.sqrt(total)))
    # Exact integer ceil division; no float round-trip.
    rows = (total + cols - 1) // cols
    # Single-channel input is broadcast to 3 channels to keep the original
    # RGB output; other channel counts are preserved instead of crashing.
    out_ch = 3 if ch == 1 else ch
    combined_image = np.zeros((height * rows, width * cols, out_ch),
                              dtype=generated_images.dtype)
    for index, image in enumerate(generated_images):
        i, j = divmod(index, cols)
        # Grid row i spans `height` pixels and grid column j spans `width`
        # pixels. The canvas allocation above uses the same axis order,
        # so non-square images fit their cells.
        combined_image[height * i:height * (i + 1),
                       width * j:width * (j + 1), :] = image
    return combined_image
评论列表
文章目录