def next(self):
    """Return the next super-resolution training batch.

    Images are read from ``self._datadir`` in ``self._img_list`` order,
    starting at index ``self.cur_batch``. Images whose height or width is
    smaller than ``self._crop_size`` are skipped. For the first image that is
    large enough:

    * a low-resolution version is built by bicubic downscaling by
      ``self._scale_factor`` and bicubic upscaling back to the original size;
    * ``self._crop_num`` random ``crop_size`` x ``crop_size`` patches are cut
      at identical positions from the low-resolution and original images;
    * each patch is mapped with ``(x - 128) / 128.0`` and reordered from
      (row, col, channel) to (channel, row, col).

    Images are kept in OpenCV's default BGR channel order (no color
    conversion is applied).

    Returns:
        ``SRDataBatch(sub_img_lr, sub_img_hr, 0)``. The two arrays are
        float32, allocated with shapes ``self._provide_data[0][1]`` and
        ``self._provide_label[0][1]``.

    Raises:
        StopIteration: if the image list is exhausted before an image of at
            least ``crop_size`` x ``crop_size`` is found.
        IOError: if OpenCV cannot decode an image file. That file has
            already been consumed, so a subsequent call moves past it.
    """
    crop_size = self._crop_size
    scale = self._scale_factor

    def _normalize_chw(patch):
        # Map 8-bit pixel values to about [-1, 1) and reorder HWC -> CHW
        # (equivalent to swapaxes(0, 2) followed by swapaxes(1, 2)).
        return ((patch - 128) / 128.0).transpose(2, 0, 1)

    # Advance through the list until an image large enough to crop is found.
    img = None
    while self.cur_batch < self.batch_num:
        img_path = os.path.join(self._datadir, self._img_list[self.cur_batch])
        self.cur_batch += 1
        candidate = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if candidate is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError("failed to read image: %s" % img_path)
        nrow, ncol = candidate.shape[0:2]
        if nrow >= crop_size and ncol >= crop_size:
            img = candidate
            break

    # Checking the found image (not cur_batch) keeps the last list entry
    # usable instead of discarding it.
    if img is None:
        raise StopIteration

    # cv2.resize needs an integer (width, height); clamp to >= 1 so that a
    # very small image cannot produce an invalid zero-sized target.
    ds_size = (max(1, int(ncol / scale)), max(1, int(nrow / scale)))
    img_ds = cv2.resize(img, ds_size, interpolation=cv2.INTER_CUBIC)
    img_lr = cv2.resize(img_ds, (ncol, nrow), interpolation=cv2.INTER_CUBIC)
    img = img.astype(npy.float32)
    img_lr = img_lr.astype(npy.float32)

    sub_img_lr = npy.zeros(self._provide_data[0][1], dtype=npy.float32)
    sub_img_hr = npy.zeros(self._provide_label[0][1], dtype=npy.float32)
    for i in range(self._crop_num):
        # randint's upper bound is exclusive; +1 allows the last valid
        # offset and avoids randint(0, 0) when a side equals crop_size.
        r0 = npy.random.randint(0, nrow - crop_size + 1)
        c0 = npy.random.randint(0, ncol - crop_size + 1)
        window = (slice(r0, r0 + crop_size), slice(c0, c0 + crop_size))
        sub_img_lr[i] = _normalize_chw(img_lr[window])
        sub_img_hr[i] = _normalize_chw(img[window])
    return SRDataBatch(sub_img_lr, sub_img_hr, 0)
评论列表
文章目录