def group_batch_images(x):
    """Tile the first ``k*k`` images of a batch into one ``k``-by-``k`` grid image.

    ``k`` is ``int(sqrt(batch_size))``; images beyond the first ``k*k`` are
    dropped. Image ``i`` of the batch is placed at grid row ``i // k``,
    column ``i % k``, with its own pixels kept contiguous.

    Args:
        x: 4-D tensor whose first (batch) dimension must be statically known
            (``math.sqrt`` is applied to it at graph-build time). Axes 1 and 2
            are treated as height and width -- presumably NHWC image data;
            confirm against callers.

    Returns:
        A tensor of shape ``[1, k * height, k * width, channels]``.
    """
    sz = x.get_shape().as_list()
    num_cols = int(math.sqrt(sz[0]))
    # Keep only as many images as exactly fill a square grid.
    img = tf.slice(x, [0, 0, 0, 0], [num_cols ** 2, -1, -1, -1])

    # Prefer static dimensions (better static shape inference downstream),
    # falling back to the runtime shape for any dim unknown at build time.
    dyn = tf.shape(img)
    height, width, channels = [
        dim if dim is not None else dyn[axis]
        for axis, dim in enumerate(sz[1:], start=1)
    ]

    # Deliberately not tf.batch_to_space: it interleaves pixels of different
    # images (output[h*k + i, w*k + j] = image[i*k + j][h, w]) instead of
    # placing whole images side by side, and its positional argument order
    # differs between TF1 and TF2. An explicit reshape/transpose builds a grid.
    # [k*k, H, W, C] -> [rows, cols, H, W, C]
    img = tf.reshape(img, [num_cols, num_cols, height, width, channels])
    # -> [rows, H, cols, W, C] so each grid row becomes one contiguous strip.
    img = tf.transpose(img, [0, 2, 1, 3, 4])
    # -> [1, rows * H, cols * W, C]
    img = tf.reshape(img, [1, num_cols * height, num_cols * width, channels])
    return img
# 文章目录 (Table of Contents)