def test_ImagePatchesByMask_3channel():
    """Sample patches from a 3-channel image guided by a binary mask.

    The 5x5x3 image holds consecutive integers, so every value encodes its
    position: img[r, c, ch] == 15 * r + 3 * c + ch. The mask has 255 on the
    diagonal and 0 elsewhere. Two (image_patch, mask_patch) pairs are
    expected; the first is checked exactly against hand-computed values.
    """
    img = np.reshape(np.arange(25 * 3), (5, 5, 3))
    mask = np.eye(5, dtype='uint8') * 255
    samples = [(img, mask)]
    # Fixed seed so the randomly chosen patch centers are reproducible.
    np.random.seed(0)
    # NOTE(review): arguments presumably are image column 0, mask column 1,
    # 3x3 patch shape, 1 positive and 1 negative sample -- confirm against
    # the ImagePatchesByMask signature. retlabel=False yields pairs only.
    get_patches = ImagePatchesByMask(0, 1, (3, 3), 1, 1, retlabel=False)
    patches = samples >> get_patches >> Collect()
    assert len(patches) == 2

    # First patch covers image rows 2-4 and columns 2-4, i.e. it is
    # centered at (3, 3); e.g. 36 == 15 * 2 + 3 * 2 + 0.
    p, m = patches[0]
    img_patch0 = np.array([[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
                           [[51, 52, 53], [54, 55, 56], [57, 58, 59]],
                           [[66, 67, 68], [69, 70, 71], [72, 73, 74]]])
    # The same window of the diagonal mask is itself a diagonal of 255s.
    mask_patch0 = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]])
    nt.assert_allclose(p, img_patch0)
    nt.assert_allclose(m, mask_patch0)

    # The second patch's position depends on the RNG, so only its shape is
    # checked: a 3x3 window with all 3 channels kept, plus a 3x3 mask window.
    p1, m1 = patches[1]
    assert np.shape(p1) == (3, 3, 3)
    assert np.shape(m1) == (3, 3)
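# Comment list (评论列表) -- scraped web-page residue, not code
# Table of contents (文章目录) -- scraped web-page residue, not code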