def load(self, name):
    """Load one image/label pair by sample name and prepare it for use.

    The image is read from ``<self.dataset.image>/<name>.jpg`` and the
    label map from ``<self.dataset.layout_image>/<name>.png`` (decoded as
    single-channel grayscale). Both are resized to ``self.target_size``.

    Args:
        name: Base file name of the sample, without extension.

    Returns:
        A tuple ``(img, lbl)``. ``img`` is whatever ``self.transform``
        returns for the resized image. ``lbl`` is an int64 tensor with a
        leading axis of size 1 whose values lie in ``[0, 4]``.

    Raises:
        OSError: If OpenCV cannot read the image or the label file.
    """
    image_path = os.path.join(self.dataset.image, '%s.jpg' % name)
    label_path = os.path.join(self.dataset.layout_image, '%s.png' % name)

    # cv2.imread signals failure by returning None instead of raising, so
    # check explicitly to report the offending path.
    img = cv2.imread(image_path)
    if img is None:
        raise OSError('failed to read image: %s' % image_path)
    lbl = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if lbl is None:
        raise OSError('failed to read label: %s' % label_path)

    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is ``dst``, so a positional flag is ignored
    # and the default (bilinear) is used. Labels need nearest-neighbour so
    # that class ids are never blended into values that are not classes.
    # NOTE(review): cv2.resize takes dsize as (width, height); confirm
    # self.target_size follows that order.
    img = cv2.resize(img, self.target_size, interpolation=cv2.INTER_LINEAR)
    lbl = cv2.resize(lbl, self.target_size, interpolation=cv2.INTER_NEAREST)

    img = self.transform(img)

    # Clamp stored label values into 1..5, then shift to zero-based 0..4.
    lbl = np.clip(lbl, 1, 5) - 1
    lbl = torch.from_numpy(np.expand_dims(lbl, axis=0)).long()
    return img, lbl
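# NOTE: web-page navigation text ("Comment list", "Table of contents") was
# captured along with this snippet here; it is not part of the code.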