def _get_next_minibatch_inds(self):
    """Return the roidb indices for the next minibatch.

    The sampling mode is selected by ``cfg.IS_RPN``:

    * RPN mode: take the next ``cfg.TRAIN.IMS_PER_BATCH`` consecutive
      entries of ``self._perm``, calling ``self._shuffle_roidb_inds()``
      first when fewer than that many entries remain.
    * Otherwise: walk ``self._perm`` one entry at a time and keep only
      images whose ``'boxes'`` array has at least one row, until
      ``cfg.TRAIN.IMS_PER_BATCH`` indices are collected. Whenever the
      cursor runs off the end, ``self._shuffle_roidb_inds()`` is called.
      Indices are not de-duplicated, so an image can repeat in a batch
      after a reshuffle.

    Returns:
        In RPN mode, a slice of ``self._perm``; otherwise an ``np.int32``
        array of length ``cfg.TRAIN.IMS_PER_BATCH``.

    Raises:
        ValueError: in non-RPN mode, if a non-empty batch is requested but
            the roidb is empty or none of its images has any boxes (either
            would otherwise make the sampling loop unable to finish).
    """
    batch_size = cfg.TRAIN.IMS_PER_BATCH
    num_images = len(self._roidb)

    if cfg.IS_RPN:
        # Reshuffle only when the remaining tail cannot fill a whole batch;
        # a ">=" test here would also throw away a tail that fits exactly.
        # NOTE(review): assumes _shuffle_roidb_inds() resets self._cur to 0
        # and regenerates self._perm -- it is defined outside this view.
        if self._cur + batch_size > num_images:
            self._shuffle_roidb_inds()
        db_inds = self._perm[self._cur:self._cur + batch_size]
        self._cur += batch_size
        return db_inds

    # Sample images, skipping those that have no ground-truth boxes.
    if batch_size > 0 and num_images == 0:
        raise ValueError('roidb is empty; cannot sample a minibatch')

    db_inds = np.zeros((batch_size,), dtype=np.int32)
    # Distinct box-less images met during this call. If it ever covers the
    # whole roidb, no image can satisfy the filter and looping on would hang.
    empty_seen = set()
    i = 0
    while i < batch_size:
        ind = self._perm[self._cur]
        if self._roidb[ind]['boxes'].shape[0] != 0:
            db_inds[i] = ind
            i += 1
        else:
            empty_seen.add(int(ind))
            if len(empty_seen) >= num_images:
                raise ValueError('no image in roidb has any boxes; '
                                 'cannot sample a minibatch')
        # Advance the cursor and start a new permutation when exhausted.
        self._cur += 1
        if self._cur >= num_images:
            self._shuffle_roidb_inds()
    return db_inds
评论列表
文章目录