b3_data_iter.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:kaggle-dstl-satellite-imagery-feature-detection 作者: u1234x1234 项目源码 文件源码
def __init__(self, masks, n_class=10, bandwidth_ratio=0.02):
        """Fit one kernel-density estimator per (mask, class channel) pair.

        For every mask and every class channel, the coordinates of the
        nonzero pixels in that channel are collected and a scikit-learn
        ``KernelDensity`` estimator is fitted to them (presumably so other
        methods can later sample positions near pixels of that class --
        confirm against the sampling code).

        Args:
            masks: iterable of arrays indexed as ``mask[:, :, class_i]``;
                ``mask.shape[2]`` must equal ``n_class``.  Presumably
                (height, width, n_class) label masks -- TODO confirm.
            n_class: number of class channels per mask (the original code
                hard-coded 10, which remains the default).
            bandwidth_ratio: KDE bandwidth expressed as a fraction of
                ``self.mask_size`` (the original code hard-coded 0.02).

        Sets:
            maps_with_class: ``n_class`` lists; list ``i`` holds the indices
                (in ``masks`` order) of masks with at least one nonzero pixel
                in channel ``i``.
            kde_samplers: one list per mask with ``n_class`` entries each --
                a fitted ``KernelDensity`` or ``None`` when the channel is
                empty in that mask.
            class_probs: uniform probability array of length ``n_class``.
            mask_size: ``shape[1]`` of the first mask, or ``None`` when
                ``masks`` is empty.

        Raises:
            AssertionError: if a mask's ``shape[2]`` differs from ``n_class``.
        """
        self.maps_with_class = [[] for _ in range(n_class)]
        self.kde_samplers = []
        self.class_probs = np.ones(n_class) / n_class
        self.mask_size = None
        ts = time.time()
        for mask_i, mask in enumerate(masks):
            assert mask.shape[2] == n_class, (
                'expected {} class channels, got {}'.format(n_class, mask.shape[2]))
            if self.mask_size is None:
                # Only the first mask's shape[1] sets the bandwidth scale;
                # NOTE(review): later masks are assumed to be the same size.
                self.mask_size = mask.shape[1]
            bandwidth = self.mask_size * bandwidth_ratio
            samplers = []
            for class_i in range(n_class):
                # (N, ndim) array of coordinates of nonzero pixels in this channel.
                coords = np.argwhere(mask[:, :, class_i])
                if coords.size == 0:
                    samplers.append(None)
                    continue
                self.maps_with_class[class_i].append(mask_i)
                # Current scikit-learn declares KernelDensity's parameters
                # keyword-only, so the bandwidth must be passed by name.
                samplers.append(neighbors.KernelDensity(bandwidth=bandwidth).fit(coords))
            self.kde_samplers.append(samplers)
        print('sampler init time: {}'.format(time.time() - ts))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号