def get_batch_idx(self, idx, **kwargs):
    """Resolve a batch of indices and fetch that batch from the data provider.

    In ``'train'`` mode every requested index is first turned into a label id,
    then one entry is drawn at random (via ``self.rnd.uniform``) from
    ``self.data_provider.label_idx[label_id]``. The label id is either
    ``index % self._real_size`` (when ``self.stats_provider`` is None) or the
    argmax of the matching row of the stats provider's ``'y_gt'`` field.
    In any other mode ``idx`` is forwarded unchanged.

    Args:
        idx: Sequence of integer indices making up the batch.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``self.data_provider.get_batch_idx`` returns for the
        resolved indices.

    Raises:
        IndexError: If a label id maps to an empty group in
            ``self.data_provider.label_idx``.
    """
    if self.mode != 'train':
        return self.data_provider.get_batch_idx(idx)

    if self.stats_provider is None:
        # Fold the raw indices into [0, self._real_size) to use them as
        # label ids (assumes _real_size is the label count -- TODO confirm).
        label_ids = [ii % self._real_size for ii in idx]
    else:
        # Label of each batch element = argmax of its 'y_gt' row
        # (presumably one-hot / score vectors -- verify against provider).
        y_gt = self.stats_provider.get_batch_idx(idx)['y_gt']
        # range (not xrange) keeps this working on both Python 2 and 3.
        label_ids = [np.argmax(y_gt[ii]) for ii in range(len(idx))]

    new_idx = []
    for label in label_ids:
        data_group = self.data_provider.label_idx[label]
        num_ids = len(data_group)
        if num_ids == 0:
            raise IndexError('No examples for label {}'.format(label))
        # uniform(0, n) may return n itself due to floating-point rounding;
        # clamp so the sampled position always stays inside data_group.
        kk = min(int(np.floor(self.rnd.uniform(0, num_ids))), num_ids - 1)
        new_idx.append(data_group[kk])
    return self.data_provider.get_batch_idx(new_idx)
评论列表
文章目录