def _stack_ragged(arrays):
    """Stack equally-shaped arrays into one ndarray, else pack them into a 1-D object array.

    ``np.array`` on a list of 2-D arrays whose shapes differ raises on recent
    NumPy (and on older NumPy when only the trailing dimension differs), so
    ragged inputs are packed element by element into an object array instead.
    When every array has the same shape the result is exactly ``np.array(arrays)``.
    """
    if len({arr.shape for arr in arrays}) <= 1:
        return np.array(arrays)
    packed = np.empty(len(arrays), dtype=object)
    for idx, arr in enumerate(arrays):
        # Assign one element at a time so NumPy never tries to broadcast the arrays together.
        packed[idx] = arr
    return packed


def load_validation_data(vali_root, input_size=320):
    """Load the validation RGB images, trimaps and alpha mattes from ``vali_root``.

    Every file name found in ``<vali_root>/alpha`` is paired with the file of
    the same name in ``<vali_root>/RGB``.

    Args:
        vali_root: directory containing the ``alpha`` and ``RGB`` subdirectories.
        input_size: side length in pixels that the RGB images and trimaps are
            resized to. Defaults to 320 (the previously hard-coded value).

    Returns:
        A tuple ``(rgb, trimaps, alphas, shapes, names)``:

        - rgb: stacked RGB images resized to ``input_size x input_size`` with
          the module-level ``g_mean`` subtracted (g_mean is presumably a
          per-channel mean -- TODO confirm).
        - trimaps: float32 array of shape ``(N, input_size, input_size, 1)``,
          built by ``generate_trimap`` from each alpha and resized with
          nearest-neighbour interpolation.
        - alphas: float32 alpha mattes in [0, 1] at their ORIGINAL resolution.
          Stacked into an ``(N, H, W)`` array when all mattes share a shape,
          otherwise a length-N object array of 2-D arrays.
        - shapes: list of each trimap's ``(H, W)`` before resizing
          (presumably equal to the alpha's shape).
        - names: sorted list of file names, in the same order as the arrays.
    """
    alpha_dir = os.path.join(vali_root, 'alpha')
    rgb_dir = os.path.join(vali_root, 'RGB')
    # os.listdir order is arbitrary; sort so the output order is reproducible.
    images = sorted(os.listdir(alpha_dir))
    target_size = (input_size, input_size)

    all_shape = []
    rgb_batch = []
    tri_batch = []
    alp_batch = []
    for name in images:
        # Force 3 channels so grayscale / RGBA files match the rest of the batch.
        rgb = misc.imread(os.path.join(rgb_dir, name), mode='RGB')
        # 'L' must be passed by keyword: positionally it binds to imread's
        # `flatten` argument, which yields a float32 image instead of 8-bit
        # grayscale (and imresize min/max-rescales non-uint8 input).
        alpha = misc.imread(os.path.join(alpha_dir, name), mode='L')
        trimap = generate_trimap(np.expand_dims(np.copy(alpha), 2),
                                 np.expand_dims(alpha, 2))[:, :, 0]
        all_shape.append(trimap.shape)
        rgb_batch.append(misc.imresize(rgb, target_size) - g_mean)
        # Nearest-neighbour keeps the trimap's discrete label values intact.
        trimap = misc.imresize(trimap, target_size, interp='nearest').astype(np.float32)
        tri_batch.append(np.expand_dims(trimap, 2))
        # Alpha is kept at its original resolution (recorded in all_shape);
        # cast first so the result is float32 in [0, 1].
        alp_batch.append(alpha.astype(np.float32) / 255.0)
    return (np.array(rgb_batch), np.array(tri_batch), _stack_ragged(alp_batch),
            all_shape, images)
评论列表
文章目录