def next(self):
    """Return the next fixed-size batch of mean-subtracted RGB inputs.

    Takes ``self.batch_size`` consecutive rows of ``self.sequence_info``
    starting at ``self.index``. Column 0 of each row is joined onto
    ``PoseNetInputProvider.BASE_DIR`` to build the file path handed to
    ``read_rgb_image``; the remaining columns are passed through unchanged
    as the ground truth for that row.

    Returns:
        A ``PoseNetInputProvider.PoseNetBatch`` with three attributes set:
        ``rgb_files`` (list of loaded inputs, each with the value returned
        by ``get_mean`` for this batch subtracted), ``groundtruths``
        (``batch_info[:, 1:]``) and ``rgb_filenames`` (the joined paths).

    Raises:
        StopIteration: when fewer than ``self.batch_size`` rows remain.
            A trailing partial batch is therefore never returned.
    """
    end = self.index + self.batch_size
    # A batch that ends exactly on the last row is still a full batch, so
    # only stop when the slice would run past the end of the data. (The
    # previous strict ``<`` test silently dropped the final full batch.)
    if end > len(self.sequence_info):
        raise StopIteration()

    batch_info = self.sequence_info[self.index:end]
    filenames = [os.path.join(PoseNetInputProvider.BASE_DIR, filename)
                 for filename in batch_info[:, 0]]

    # Build a real list: the loaded inputs are traversed twice (once by
    # get_mean, once for the subtraction). On Python 3 a bare ``map`` is a
    # one-shot iterator and would be empty on the second pass.
    rgb_files = [read_rgb_image(filename) for filename in filenames]
    mean = get_mean(rgb_files)
    self.logger.info('Mean for current batch:{}'.format(mean))

    batch = PoseNetInputProvider.PoseNetBatch()
    batch.rgb_files = [rgb_file - mean for rgb_file in rgb_files]
    batch.groundtruths = batch_info[:, 1:]
    batch.rgb_filenames = filenames

    # Advance only after the batch was assembled successfully.
    self.index = end
    return batch
评论列表
文章目录