iterator.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:additions_mxnet 作者: eldercrow 项目源码 文件源码
def _data_augmentation(self, data, label):
        """
        Perform data augmentation on one image/label pair.

        Steps, in order: random crop (training with a sampler only), resize
        to ``self._data_shape`` with a randomly chosen interpolation when
        training (bilinear otherwise), random erasing (training with an
        eraser only), random horizontal mirror (training with mirroring
        enabled only), HWC -> CHW transpose, cast to float32, and mean
        subtraction.

        Parameters
        ----------
        data : mx.nd.NDArray
            Image in HWC layout (``shape[0]`` is used as height, ``shape[1]``
            as width, and the final transpose moves channels first).
        label : np.ndarray
            One row per object. Columns 1..4 are the box as
            (xmin, ymin, xmax, ymax); columns 1 and 3 are x coordinates.
            Rows whose entries are all -1 are skipped when mirroring
            (presumably padding rows).
            NOTE(review): coordinates appear to be normalized to [0, 1],
            since they are multiplied by the output size before erasing.
            Confirm against the dataset loader.
            NOTE(review): when no random crop is applied, mirroring modifies
            this array in place. Callers should pass a copy if they cache
            labels.

        Returns
        -------
        tuple
            ``(data, label)``: the CHW float32 image with
            ``self._mean_pixels`` subtracted, and the (possibly cropped and
            mirrored) label.
        """
        # Right edge of the image in label coordinates, used to reflect x
        # when mirroring. Without a crop the labels are treated as
        # normalized, so the edge is 1.0. A crop reports its own edge in
        # rand_crop[2]. (Previously rand_crop was unbound here when mirroring
        # was enabled without a sampler.)
        mirror_edge = 1.0
        if self.is_train and self._rand_sampler:
            height, width = data.shape[0], data.shape[1]
            rand_crop = self._rand_sampler.sample(label, (height, width))
            xmin, ymin, xmax, ymax = np.array(rand_crop[0]).astype(int)
            data = crop_roi_patch(data.asnumpy(), (xmin, ymin, xmax, ymax))
            label = rand_crop[1]
            mirror_edge = rand_crop[2]

        if self.is_train:
            interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA,
                              cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
        else:
            interp_methods = [cv2.INTER_LINEAR]
        # uniform * len (rather than randint) keeps the RNG draws identical
        # to earlier versions, so seeded runs stay reproducible.
        interp_method = interp_methods[int(np.random.uniform(0, 1) * len(interp_methods))]
        # _data_shape is (height, width); mx.img.imresize takes (w, h).
        out_h, out_w = self._data_shape[0], self._data_shape[1]
        data = mx.img.imresize(data, out_w, out_h, interp_method)

        if self._rand_eraser and self.is_train:
            # Scale (xmin, ymin, xmax, ymax) to pixels: x by width, y by
            # height. The old scaler used (h, w, h, w), which is wrong for
            # non-square outputs. Only columns 1..4 are boxes, so extra label
            # columns no longer break the broadcast.
            label_scaler = np.array([[out_w, out_h, out_w, out_h]])
            boxes_px = label[:, 1:5] * label_scaler
            data = mx.nd.array(self._rand_eraser.sample(data.asnumpy(), boxes_px))

        # The coin flip is drawn only when mirroring is enabled, as before.
        if self.is_train and self._rand_mirror and np.random.uniform(0, 1) > 0.5:
            valid_mask = np.where(np.any(label != -1, axis=1))[0]
            # Flip along the width axis of the HWC image.
            data = mx.nd.flip(data, axis=1)
            # Reflect x about mirror_edge. The old xmax becomes the new xmin
            # (and vice versa), so xmin <= xmax still holds. Fancy indexing
            # returns a copy, so new_xmax is unaffected by the next write.
            new_xmax = mirror_edge - label[valid_mask, 1]
            label[valid_mask, 1] = mirror_edge - label[valid_mask, 3]
            label[valid_mask, 3] = new_xmax

        data = mx.nd.transpose(data, (2, 0, 1))
        data = data.astype('float32')
        data = data - self._mean_pixels
        return data, label
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号