def make_grid(I, ncols=8):
    """Tile a batch of 3-channel images into a single CHW grid image.

    Images are laid out row-major: image ``i`` goes to grid row
    ``i // ncols`` and grid column ``i % ncols``. Grid cells in the last row
    that receive no image are left as zeros.

    Args:
        I: numpy array of shape ``(N, 3, H, W)`` holding ``N`` images.
        ncols: maximum number of images per grid row (default 8). It is
            clamped to ``N`` so a small batch does not produce empty columns.

    Returns:
        numpy array of shape ``(3, H * nrows, W * ncols)`` where
        ``nrows = ceil(N / ncols)``, with the same dtype as ``I``.
        An empty batch (``N == 0``) yields an array of shape ``(3, 0, 0)``.

    Raises:
        AssertionError: if ``I`` is not a numpy array of shape ``(N, 3, H, W)``.
        ValueError: if ``ncols`` is less than 1.
    """
    assert isinstance(I, np.ndarray), 'plugin error, should pass numpy array here'
    assert I.ndim == 4 and I.shape[1] == 3, 'expected an (N, 3, H, W) array'
    if ncols < 1:
        raise ValueError('ncols must be at least 1, got {}'.format(ncols))

    nimg, channels, H, W = I.shape
    # Never use more columns than there are images.
    ncols = min(nimg, ncols)
    if nimg == 0:
        # Nothing to tile; also avoids dividing by ncols == 0 below.
        return np.zeros((channels, 0, 0), dtype=I.dtype)

    nrows = -(-nimg // ncols)  # integer ceiling division
    # Keep the input dtype: a default float64 canvas would silently turn
    # e.g. uint8 images into float64 values.
    canvas = np.zeros((channels, H * nrows, W * ncols), dtype=I.dtype)
    for i, img in enumerate(I):
        row, col = divmod(i, ncols)
        canvas[:, row * H:(row + 1) * H, col * W:(col + 1) * W] = img
    return canvas
# Comment list
# Table of contents