def dice_error(input, target, eps=0.000001):
    """Return the Dice coefficient between an argmax prediction and a target.

    Despite its name, this returns the Dice *score* (1.0 means perfect
    overlap), not ``1 - score``.

    The predicted label of every element is the argmax of ``input`` along
    dim 1. That label map and ``target`` are flattened and combined with a
    dot product and sums. The result is therefore only a meaningful Dice
    score when both hold 0/1 values — presumably a two-class
    (background/foreground) setup; TODO confirm against callers.

    Args:
        input: Score/probability tensor with classes along dim 1.
        target: Tensor with as many elements as ``input.max(1)`` yields.
            It may have any shape, because it is flattened.
        eps: Smoothing term. Two copies are added to the denominator, and
            the intersection is clamped to at least ``eps``. As a result, an
            empty prediction against an empty target scores 1.0 instead of
            0/0.

    Returns:
        A 0-d float32 tensor on ``input``'s device holding the Dice score.
        No gradient is tracked.

    Raises:
        RuntimeError: if the flattened prediction and target differ in
            length (raised by ``torch.dot``).
    """
    _, labels = input.max(1)
    # torch.dot needs 1-D operands. reshape(-1) handles a single element
    # (squeeze() made it 0-d) and batched/spatial label maps alike.
    result = labels.reshape(-1).float()
    # Move the target onto the prediction's device as float, detached from
    # autograd. This replaces the legacy (cuda.)FloatTensor + copy_(x.data).
    target = target.detach().reshape(-1).to(device=result.device,
                                            dtype=torch.float32)

    # The target volume can be empty, so we still want a score of 1 when the
    # result is 0/0. Clamping the intersection to eps gives
    # 2*eps / (2*eps) = 1 in that case. This stays in torch, unlike np.max,
    # which cannot convert CUDA tensors.
    intersect = torch.dot(result, target).clamp(min=eps)
    union = result.sum() + target.sum() + 2 * eps
    return 2 * intersect / union
评论列表
文章目录