def grid_vis(X, nh, nw):
    """Tile a batch of images into an ``nh`` x ``nw`` grid.

    Images are placed row-major: image ``n`` lands in grid row ``n // nw``
    and grid column ``n % nw``. Cells with no image keep a white
    background (255 for ``uint8`` input, 1 for any other dtype).

    Args:
        X: Array-like batch of images shaped ``(N, h, w)`` or
            ``(N, h, w, c)``. Single-channel images are replicated to
            3 channels.
        nh: Number of grid rows.
        nw: Number of grid columns.

    Returns:
        If ``N == 1``, the single image ``X[0]`` unchanged (no grid layout).
        Otherwise an array of shape ``(h * nh, w * nw, c)`` with the dtype
        of ``X`` (``c == 3`` for single-channel input).

    Raises:
        ValueError: If ``X`` is not 3- or 4-dimensional, or holds more
            images than the grid has cells.
    """
    X = np.asarray(X)
    if X.shape[0] == 1:
        # A single image is handed back as-is rather than laid out in a grid.
        return X[0]
    if X.ndim == 3:
        # (N, h, w) -> (N, h, w, 1) so every image has an explicit channel axis.
        X = X[..., np.newaxis]
    if X.ndim != 4:
        raise ValueError(
            f"expected images shaped (N, h, w) or (N, h, w, c), got {X.shape}"
        )
    if X.shape[-1] == 1:
        # Replicate grayscale to 3 channels.
        X = np.tile(X, [1, 1, 1, 3])

    n_images, h, w, nc = X.shape
    if n_images > nh * nw:
        # Overflow images would map to rows past the canvas (empty slices).
        raise ValueError(
            f"{n_images} images do not fit in a {nh}x{nw} grid"
        )

    background = 255 if X.dtype == np.uint8 else 1
    img = np.full((h * nh, w * nw, nc), background, dtype=X.dtype)
    for n, x in enumerate(X):
        row, col = divmod(n, nw)
        img[row * h:(row + 1) * h, col * w:(col + 1) * w] = x
    return img
评论列表
文章目录