utils_pro.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:cav_gcnn 作者: myinxd 项目源码 文件源码
def gen_sample_multi(img, img_mark, rate=0.2, boxsize=10, px_over=5):
    """
    Generate samples by splitting the pixel classified image into
    (possibly overlapping) square boxes of the provided boxsize.

    Input
    -----
    img: np.ndarray
        The 2D raw image
    img_mark: np.ndarray
        The 2D marked image, indexed with the same box coordinates as
        ``img``. Pixels valued 0, 127 and 255 are counted as classes
        0, 1 and 2 respectively; other values are not counted.
    rate: float
        Minimum fraction of class-1 (value 127, cavity) pixels in a box
        for the box to be labelled 1, belongs to (0,1), default as 0.2
    boxsize: integer
        Side length of the square box, default as 10; must be >= 1
    px_over: integer
        Overlapped pixels between neighbouring boxes, default as 5;
        must be smaller than boxsize

    Output
    ------
    data: np.ndarray
        Matrix of shape (n_boxes, boxsize**2); each ROW holds one
        flattened sample. Boxes are ordered row-major (top-to-bottom,
        then left-to-right) and only boxes lying fully inside the image
        are generated.
    label: np.ndarray
        Labels of shape (n_boxes, 1) with respect to samples,
        could be 0, 1, and 2.

    Raises
    ------
    ValueError
        If boxsize < 1 or px_over >= boxsize (non-positive stride).
    """
    if boxsize < 1:
        raise ValueError("boxsize must be a positive integer")
    px_diff = boxsize - px_over
    if px_diff <= 0:
        raise ValueError("px_over must be smaller than boxsize")

    rows, cols = img.shape
    # Number of whole boxes that fit along each axis with stride px_diff.
    # The previous np.round-based estimate could overshoot the border,
    # yielding truncated boxes that failed to reshape, or go negative.
    box_rows = max(0, (rows - boxsize) // px_diff + 1)
    box_cols = max(0, (cols - boxsize) // px_diff + 1)
    n_boxes = box_rows * box_cols
    npix = boxsize ** 2

    # init data and label
    data = np.zeros((n_boxes, npix))
    label = np.zeros((n_boxes, 1))

    # Split
    for i in range(box_rows):
        r0 = i * px_diff
        for j in range(box_cols):
            c0 = j * px_diff
            # Row-major flat index: the row stride is box_cols. Using
            # box_rows here made samples collide or overflow whenever
            # the box grid was not square.
            idx = i * box_cols + j
            data[idx, :] = img[r0:r0 + boxsize,
                               c0:c0 + boxsize].reshape((npix,))
            mask = img_mark[r0:r0 + boxsize,
                            c0:c0 + boxsize].reshape((npix,))
            counts = np.array([np.count_nonzero(mask == 0),
                               np.count_nonzero(mask == 127),
                               np.count_nonzero(mask == 255)])
            # float() keeps true division under Python 2 as well.
            if float(counts[1]) / npix >= rate:
                label[idx, 0] = 1
            else:
                # Majority class; argmax returns the first maximum, so
                # ties resolve toward the lower label.
                label[idx, 0] = np.argmax(counts)

    return data, label
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号