def forward(self, input, target, save=True):
    """Return the Dice overlap between the arg-max prediction and ``target``.

    The predicted label of every row is the index of its largest value
    along dim 1 of ``input``. The prediction and ``target`` are copied into
    float tensors and the score ``2 * <pred, target> / (sum(pred) +
    sum(target) + 2 * eps)`` is computed from them.

    Side effects: sets ``self.target_`` (float copy of ``target``),
    ``self.intersect`` and ``self.union`` (presumably consumed by the
    matching backward pass -- not visible here), and prints a diagnostic
    line on every call.

    Args:
        input: score tensor; presumably ``(num_voxels, num_classes)`` with
            two classes -- TODO confirm against the caller.
        target: 1-D tensor with one label per prediction; assumed to hold
            only 0/1 values.
        save: when true, ``input`` and ``target`` are passed to
            ``self.save_for_backward``.

    Returns:
        A float tensor with one element holding the score, allocated with
        the same tensor type/device as the intermediate tensors.
    """
    if save:
        self.save_for_backward(input, target)
    eps = 0.000001

    # Only the indices of the maxima are needed, not the maxima themselves.
    _, predicted = input.max(1)
    # Flatten instead of squeeze: squeeze turns a single-sample prediction
    # into a 0-dim tensor, which torch.dot (1-D only) rejects. For every
    # shape the squeeze-based code accepted, the result is the same.
    predicted = predicted.contiguous().view(-1)

    # Work in float on the same kind of device as the input.
    tensor_type = torch.cuda.FloatTensor if input.is_cuda else torch.FloatTensor
    result = tensor_type(predicted.size())
    self.target_ = tensor_type(target.size())
    result.copy_(predicted)
    self.target_.copy_(target)
    target = self.target_

    intersect = torch.dot(result, target)
    # Assuming 0/1 values, the plain sum equals the sum of squares.
    result_sum = torch.sum(result)
    target_sum = torch.sum(target)
    union = result_sum + target_sum + (2 * eps)

    # eps keeps the division finite when both volumes are empty.
    # NOTE(review): eps is only in the denominator, so the empty/empty
    # case evaluates to 0 / (2 * eps) == 0, not a score of 1 -- confirm
    # the desired behaviour before changing the formula.
    IoU = intersect / union
    # The last field is labelled "IoU" but shows 2 * intersect / union,
    # i.e. the returned score; the text is kept as-is for log compatibility.
    print('union: {:.3f}\t intersect: {:.6f}\t target_sum: {:.0f} IoU: result_sum: {:.0f} IoU {:.7f}'.format(
        union, intersect, target_sum, result_sum, 2*IoU))
    # Allocate from `result` so the output lives where the input does
    # instead of always being a CPU tensor.
    out = result.new(1).fill_(2 * IoU)
    self.intersect, self.union = intersect, union
    return out
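# Web-page residue from where this snippet was copied ("Comment list" /
# "Table of contents"); not part of the code.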