def get_gt(gt, n_classes, downsize=False, zoom_factor=0.5):
    """Return the ground-truth label map, optionally resampled to a lower resolution.

    When ``downsize`` is false, ``gt`` is returned unchanged (the same object).
    Otherwise each class is turned into a binary float mask, every mask is
    resampled with linear interpolation (``zoom(..., order=1)``), and the
    per-element argmax over the resampled masks becomes the new label.
    Resampling per-class masks instead of the label map itself avoids
    interpolating between unrelated class ids.

    Args:
        gt: Label array of any rank; values must lie in ``[0, n_classes)``.
            Non-integer input is cast to int (truncating toward zero).
        n_classes: Number of classes (must be >= 1).
        downsize: Whether to resample at all.
        zoom_factor: Scale factor passed to ``zoom``. The default ``0.5``
            halves every axis, which matches the previous hard-coded behaviour.

    Returns:
        ``gt`` itself when ``downsize`` is false. Otherwise, a label array with
        the resampled shape and dtype ``int8``, or ``int32`` when
        ``n_classes`` > 128, because ``int8`` would overflow.

    Raises:
        ValueError: If ``n_classes`` < 1 or a label lies outside
            ``[0, n_classes)``.
    """
    if not downsize:
        return gt

    if n_classes < 1:
        raise ValueError(f"n_classes must be >= 1, got {n_classes}")
    labels = np.asarray(gt).astype(np.int64, copy=False)
    # Out-of-range labels would match no class mask and silently become
    # class 0 after argmax, so reject them up front.
    if labels.size and (labels.min() < 0 or labels.max() >= n_classes):
        raise ValueError(
            f"labels must lie in [0, {n_classes}), got range "
            f"[{labels.min()}, {labels.max()}]"
        )

    # Build one float mask per class (instead of materialising the whole
    # one-hot tensor) and resample it. Float input keeps the interpolated
    # values fractional; an integer zoom would round them back to 0/1.
    # Works for any rank, since zoom applies the scalar factor to every axis.
    zoomed = np.stack([
        zoom((labels == c).astype(np.float32), zoom_factor, order=1)
        for c in range(n_classes)
    ])
    # int8 holds class ids only up to 127; widen beyond that to avoid wrap-around.
    out_dtype = np.int8 if n_classes <= 128 else np.int32
    return zoomed.argmax(axis=0).astype(out_dtype)
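# Leftover web-page navigation text from the source this code was copied from
# (originally the headings "Comment list" and "Table of contents"); not code.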