label_sample_data_provider.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:tfplus 作者: renmengye 项目源码 文件源码
def get_batch_idx(self, idx, **kwargs):
        if self.mode == 'train':
            new_idx = []
            # self.log.info('Label IDX: {}'.format(idx))
            if self.stats_provider is None:
                label_ids = [ii % self._real_size for ii in idx]
            else:
                # print idx, self.stats_provider.get_size()
                stats_batch = self.stats_provider.get_batch_idx(idx)
                label_ids = []
                for ii in xrange(len(idx)):
                    label_ids.append(np.argmax(stats_batch['y_gt'][ii]))

            for ii in label_ids:
                data_group = self.data_provider.label_idx[ii]
                num_ids = len(data_group)
                kk = int(np.floor(self.rnd.uniform(0, num_ids)))
                new_idx.append(data_group[kk])
        else:
            new_idx = idx
        return self.data_provider.get_batch_idx(new_idx)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号