def EPE(input_flow, target_flow, sparse=False, mean=True):
    """Return the end-point error between a predicted and a target flow.

    The error at each position is the L2 norm, taken along dim 1, of
    ``target_flow - input_flow``.

    Args:
        input_flow: Predicted flow tensor. Presumably shaped
            (batch, 2, H, W) -- TODO confirm against callers; the code only
            requires that dim 1 holds the flow components.
        target_flow: Ground-truth flow tensor, broadcast-compatible with
            ``input_flow``.
        sparse: When true, positions where every component of
            ``target_flow`` along dim 1 is exactly 0 are treated as having
            no ground truth and are left out of the reduction.
        mean: When true, return the mean of the retained errors; otherwise
            return their sum.

    Returns:
        A zero-dimensional tensor holding the reduced error. With
        ``sparse=True`` and no valid position, the mean is NaN and the sum
        is 0 (the reduction runs over an empty tensor).
    """
    epe_map = torch.norm(target_flow - input_flow, p=2, dim=1)
    if sparse:
        # The norm removed dim 1, so the validity mask must be collapsed
        # over that same dim; a mask built directly from
        # `target_flow != 0` keeps the component dim and cannot index
        # the per-position error map.
        valid = (target_flow != 0).any(dim=1)
        epe_map = epe_map[valid]
    if mean:
        return epe_map.mean()
    return epe_map.sum()
评论列表
文章目录