test_transformer.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:nuts-ml 作者: maet3608 项目源码 文件源码
def test_ImagePatchesByMask_3channel():
    img = np.reshape(np.arange(25 * 3), (5, 5, 3))
    mask = np.eye(5, dtype='uint8') * 255
    samples = [(img, mask)]

    np.random.seed(0)
    get_patches = ImagePatchesByMask(0, 1, (3, 3), 1, 1, retlabel=False)
    patches = samples >> get_patches >> Collect()
    assert len(patches) == 2

    p, m = patches[0]
    img_patch0 = np.array([[[36, 37, 38], [39, 40, 41], [42, 43, 44]],
                           [[51, 52, 53], [54, 55, 56], [57, 58, 59]],
                           [[66, 67, 68], [69, 70, 71], [72, 73, 74]]])
    mask_patch0 = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]])
    nt.assert_allclose(p, img_patch0)
    nt.assert_allclose(m, mask_patch0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号