def _tile_starts(img_len, len_out, n_tile):
    """Return sorted start offsets of ``n_tile`` windows of width ``len_out`` covering ``[0, img_len)``.

    Starts are ``floor(i * (img_len - len_out) / (n_tile - 1))`` for
    ``i = 0 .. n_tile - 1``. Consecutive starts therefore differ by at most
    ``ceil((img_len - len_out) / (n_tile - 1)) <= len_out``, so there are no
    gaps. The first window begins at 0 and the last ends exactly at
    ``img_len``. Duplicate starts (possible when the image is barely larger
    than one window) are removed.

    Raises:
        ValueError: if ``n_tile < 1`` or ``img_len`` is not within
            ``[len_out, len_out * n_tile]`` (the windows could not cover the
            axis without leaving gaps or running outside the image).
    """
    if n_tile < 1 or not len_out <= img_len <= len_out * n_tile:
        raise ValueError(
            "{} must be between {} and {} x {}".format(img_len, len_out, len_out, n_tile))
    if n_tile == 1:
        # Only reachable when img_len == len_out: a single window at 0.
        return [0]
    span = img_len - len_out
    return sorted({i * span // (n_tile - 1) for i in range(n_tile)})


def validate(model):
    """Segment every validation volume with tiled inference and return the mean Dice.

    For each (image, label) path pair in ``df_val``, the image is loaded and
    covered along each spatial axis by ``args.n_tiles`` overlapping output
    tiles of size ``args.output_shape``. For each tile, a patch of size
    ``args.input_shape`` is cropped around the tile center and run through
    ``model``. The central ``output_shape`` region of the prediction is then
    summed into a per-class score volume. The argmax over classes gives the
    label map, which is scored with ``dice_coefficients``.

    Module-level names read: ``df_val``, ``args`` (``input_shape``,
    ``output_shape``, ``n_tiles``), ``dataset["n_classes"]``, ``xp``,
    ``load_nifti``, ``crop_patch``, ``dice_coefficients``, ``chainer``.

    Args:
        model: callable applied to a batch of one patch; its result's
            ``.data`` is indexed as ``[batch, class, *spatial]`` (presumably a
            Chainer Variable of shape (1, n_classes, X, Y, Z) — confirm).

    Returns:
        ``np.mean`` over volumes (axis 0) of the per-volume results of
        ``dice_coefficients``, presumably one Dice score per label. An empty
        ``df_val`` yields NaN (NumPy mean of an empty array).

    Raises:
        ValueError: if an image axis cannot be tiled by ``n_tiles`` windows
            of ``output_shape`` (see ``_tile_starts``).
    """
    n_classes = dataset["n_classes"]
    # Offset of the central output region inside each model input patch;
    # depends only on the configured shapes, so compute it once.
    slices_in = tuple(
        slice((len_in - len_out) // 2, (len_in - len_out) // 2 + len_out)
        for len_out, len_in in zip(args.output_shape, args.input_shape)
    )

    dice_coefs = []
    for image_path, label_path in zip(df_val["image"], df_val["label"]):
        image = load_nifti(image_path)
        label = load_nifti(label_path)

        # zip truncates to len(args.output_shape), so only the spatial axes
        # are tiled even though image has a trailing extra axis.
        starts_per_axis = [
            _tile_starts(img_len, len_out, n_tile)
            for img_len, len_out, n_tile in zip(image.shape, args.output_shape, args.n_tiles)
        ]

        # NOTE(review): image.shape[:-1] assumes the last image axis is a
        # channel axis, e.g. (X, Y, Z, C) — confirm against load_nifti.
        output = np.zeros((n_classes,) + image.shape[:-1])
        for starts in itertools.product(*starts_per_axis):
            center = [start + len_out // 2 for start, len_out in zip(starts, args.output_shape)]
            patch = crop_patch(image, center, args.input_shape)
            patch = xp.asarray(np.expand_dims(patch, 0))
            slices_out = tuple(
                slice(start, start + len_out)
                for start, len_out in zip(starts, args.output_shape)
            )
            # NOTE(review): the forward pass runs under whatever Chainer config
            # is active; if the model has dropout/batch-norm, consider wrapping
            # it in chainer.using_config('train', False) and
            # chainer.no_backprop_mode() (Chainer v2+).
            scores = model(patch).data[(0, slice(None)) + slices_in]
            # Overlapping tiles are summed; argmax below picks the class with
            # the largest accumulated score per voxel.
            output[(slice(None),) + slices_out] += chainer.cuda.to_cpu(scores)

        prediction = np.argmax(output, axis=0).astype(np.int32)
        dice_coefs.append(dice_coefficients(prediction, label, labels=range(n_classes)))

    return np.mean(np.array(dice_coefs), axis=0)
评论列表
文章目录