def sub_load_data(data, img_size, aug):
    """Load one sample (image + labels) and optionally pad / augment it.

    Args:
        data: pair ``(img_name, dataset)``. ``dataset`` is a path prefix that
            is concatenated directly with ``'images/'``, ``'seg_labels/'``,
            ``'ori_labels/'`` and ``'mnt_labels/'`` (so it should end with a
            path separator).
        img_size: target ``(rows, cols)`` of the returned arrays. May be a
            tuple, list or NumPy array.
        aug: probability in ``[0, 1]``, compared against ``np.random.rand()``.
            It gates (a) random instead of centred placement when padding and
            (b) the random rotation + translation augmentation.

    Returns:
        Tuple ``(img, seg, ali, mnt)``: the image, the segmentation label, the
        orientation label (all zeros when the orientation file cannot be
        read) and the minutiae array restricted to points at least 8 pixels
        inside the ``img_size`` frame.

    Notes:
        ``mnt`` is treated as an ``(N, 3)`` array: column 0 is shifted by the
        column offset and bounded by ``img_size[1]``, column 1 is shifted by
        the row offset and bounded by ``img_size[0]``, and column 2 has the
        rotation (in radians) subtracted from it -- presumably the minutia
        direction in radians; confirm against ``mnt_reader``.
    """
    img_name, dataset = data
    img = misc.imread(dataset+'images/'+img_name+'.bmp', mode='L')
    seg = misc.imread(dataset+'seg_labels/'+img_name+'.png', mode='L')
    try:
        ali = misc.imread(dataset+'ori_labels/'+img_name+'.bmp', mode='L')
    except (IOError, OSError):
        # Orientation labels are optional: fall back to an all-zero map when
        # the file is missing/unreadable, without swallowing KeyboardInterrupt
        # or programming errors as a bare ``except`` would.
        ali = np.zeros_like(img)
    mnt = np.array(mnt_reader(dataset+'mnt_labels/'+img_name+'.mnt'), dtype=float)

    # Compare shapes element-wise as arrays. ``tuple != tuple`` yields a plain
    # bool, which the builtin ``any`` cannot iterate, so the original check
    # raised TypeError whenever ``img_size`` was not a NumPy array.
    target_shape = np.asarray(img_size)
    source_shape = np.asarray(img.shape)
    if np.any(source_shape != target_shape):
        # random pad mean values to reach required shape
        margin = target_shape - source_shape
        if np.random.rand() < aug:
            tra = np.int32(np.random.rand(2) * margin)
        else:
            tra = np.int32(0.5 * margin)
        rows = slice(tra[0], tra[0] + img.shape[0])
        cols = slice(tra[1], tra[1] + img.shape[1])
        # Image and orientation are padded with their own mean value, the
        # segmentation mask with zeros (background).
        img_t = np.ones(img_size) * np.mean(img)
        seg_t = np.zeros(img_size)
        ali_t = np.ones(img_size) * np.mean(ali)
        img_t[rows, cols] = img
        seg_t[rows, cols] = seg
        ali_t[rows, cols] = ali
        img = img_t
        seg = seg_t
        ali = ali_t
        # Minutiae are stored column-first, hence the swapped offset order.
        mnt = mnt + np.array([tra[1], tra[0], 0])

    if np.random.rand() < aug:
        # random rotation [0 - 360] & translation img_size / 4
        rot = np.random.rand() * 360
        tra = (np.random.rand(2) - 0.5) / 2 * target_shape
        img = ndimage.rotate(img, rot, reshape=False, mode='reflect')
        img = ndimage.shift(img, tra, mode='reflect')
        # The mask is filled with zeros rather than reflected.
        seg = ndimage.rotate(seg, rot, reshape=False, mode='constant')
        seg = ndimage.shift(seg, tra, mode='constant')
        ali = ndimage.rotate(ali, rot, reshape=False, mode='reflect')
        ali = ndimage.shift(ali, tra, mode='reflect')
        rot_rad = rot / 180 * np.pi
        mnt_r = point_rot(mnt[:, :2], rot_rad, img.shape, img.shape)
        mnt = np.column_stack((mnt_r + tra[[1, 0]], mnt[:, 2] - rot_rad))

    # only keep mnt that stay in pic & not on border
    inside = (
        (8 <= mnt[:, 0]) & (mnt[:, 0] < img_size[1] - 8)
        & (8 <= mnt[:, 1]) & (mnt[:, 1] < img_size[0] - 8)
    )
    mnt = mnt[inside, :]
    return img, seg, ali, mnt
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)