def build_image_grid(image_batch, row, col, image_size=None):
    """Tile a batch of square RGB images into a single uint8 image grid.

    Images are laid out in column-major order: image ``k`` of the batch
    lands in grid row ``k % row`` and grid column ``k // row``.

    Args:
        image_batch: Tensor holding ``row * col`` square 3-channel images.
            It must contain exactly ``row * col * image_size * image_size * 3``
            elements for the first reshape to succeed; presumably shaped
            ``[row * col, image_size, image_size, 3]`` and of a float dtype
            (TODO confirm against callers). Pixel values are expected in
            ``[-1, 1]``; anything that maps outside ``[0, 255]`` is clamped
            by the saturating cast.
        row: Number of image rows in the grid.
        col: Number of image columns in the grid.
        image_size: Side length, in pixels, of each image. Defaults to
            ``FLAGS.image_size``, which matches the original behaviour.

    Returns:
        A ``tf.uint8`` tensor of shape
        ``[row * image_size, col * image_size, 3]``.
    """
    if image_size is None:
        image_size = FLAGS.image_size
    # Stack every image vertically into one tall strip:
    # [1, row * col * image_size, image_size, 3].
    grid = tf.reshape(
        image_batch, [1, row * col * image_size, image_size, 3])
    # Cut the strip into `col` equal pieces of `row` consecutive images, then
    # place those pieces side by side. Each piece becomes one grid column:
    # [1, row * image_size, col * image_size, 3].
    grid = tf.split(grid, col, axis=1)
    grid = tf.concat(grid, axis=2)
    # Map [-1, 1] -> [0, 255]; saturate_cast clamps out-of-range values
    # instead of letting them wrap around in uint8.
    grid = tf.saturate_cast(grid * 127.5 + 127.5, tf.uint8)
    # Drop the leading batch dimension of size 1.
    return tf.reshape(grid, [row * image_size, col * image_size, 3])
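# 评论列表 ("Comment list") / 文章目录 ("Table of contents"):
# these appear to be leftover web-page navigation labels, not part of the code.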