def getBatch_(self, indices):
    """Assemble a batch of RGB images and label maps for the given indices.

    Each index selects an ``(rgbname, gtname)`` pair from ``self.flist``.
    The image is read with ``misc.imread``. The label map is read with
    ``misc.imread`` when ``gtname`` ends in ``.png`` and with ``np.loadtxt``
    otherwise. When ``self.data_transformer`` is not ``None``, the image goes
    through ``transformData`` and the label map through ``transformLabel``
    before being copied into the batch.

    Args:
        indices: Sized iterable of positions into ``self.flist``.

    Returns:
        ``(batchRGB, batchLabel)``, or ``(batchRGB, batchLabel,
        self.weights_classes)`` when ``self.weights_classes_flag`` is truthy.
        ``batchRGB`` is float32 with shape ``(N, self.CH, self.W, self.H)``
        and ``batchLabel`` is int32 with shape ``(N, self.W, self.H)``,
        where ``N = len(indices)``.
    """
    # Batch layout: N x CH x W x H.
    n_samples = len(indices)
    batchRGB = np.zeros((n_samples, self.CH, self.W, self.H), dtype='float32')
    batchLabel = np.zeros((n_samples, self.W, self.H), dtype='int32')

    for k, i in enumerate(indices):
        rgbname, gtname = self.flist[i]

        # Image is loaded as H x W x CH.
        rgb = misc.imread(rgbname)

        if gtname.endswith('.png'):
            gt = misc.imread(gtname)
        else:
            gt = np.loadtxt(gtname)
        # NOTE(review): the original indentation was lost, so it is unclear
        # whether this cast applied to both branches or only to the
        # np.loadtxt branch. It is applied to both here; for 8-bit PNG label
        # maps the two readings are equivalent. Confirm against the upstream
        # source.
        gt = gt.astype('uint8')

        if self.data_transformer is not None:
            # Per the original note, the transformer outputs H x W x CH.
            rgb = self.data_transformer.transformData(rgb)
            gt = self.data_transformer.transformLabel(gt)

        # Conversion from H x W x CH to CH x W x H (labels: H x W to W x H).
        batchRGB[k, :, :, :] = rgb.astype(np.float32).transpose((2, 1, 0))
        batchLabel[k, :, :] = gt.astype(np.int32).transpose((1, 0))

    if self.weights_classes_flag:
        return (batchRGB, batchLabel, self.weights_classes)
    return (batchRGB, batchLabel)
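# [page residue] "Comment list" / "Table of contents" -- scraped web-page navigation labels, not part of the program.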