def get_mean_error(ret_nodes, nodes, assumedNodesPresent, trueNodesPresent):
    '''
    Computes average displacement error

    Parameters
    ==========
    ret_nodes : A tensor of shape pred_length x numNodes x 2
        Contains the predicted positions for the nodes
    nodes : A tensor of shape pred_length x numNodes x 2
        Contains the true positions for the nodes
    assumedNodesPresent : An iterable of nodeIDs
        The nodes that are candidates for scoring at every time-step
    trueNodesPresent : A list of lists, of size pred_length
        Each list contains the nodeIDs of the nodes actually present at that
        time-step; a node is scored at a time-step only if it appears both in
        assumedNodesPresent and in trueNodesPresent[tstep]

    Returns
    =======
    Error : 0-dim tensor, the mean over time-steps of the per-time-step mean
        euclidean distance between predicted and true positions. A time-step
        at which no node is scored contributes 0 to this mean.
    '''
    pred_length = ret_nodes.size()[0]
    # Allocate on the inputs' device instead of hard-coding .cuda(), so the
    # metric also works on CPU-only machines.
    error = torch.zeros(pred_length, device=ret_nodes.device)
    for tstep in range(pred_length):
        present = trueNodesPresent[tstep]
        counter = 0
        for nodeID in assumedNodesPresent:
            # Skip assumed nodes that are not actually present at this step.
            if nodeID not in present:
                continue
            pred_pos = ret_nodes[tstep, nodeID, :]
            true_pos = nodes[tstep, nodeID, :]
            error[tstep] += torch.norm(pred_pos - true_pos, p=2)
            counter += 1
        if counter != 0:
            error[tstep] = error[tstep] / counter
    # NOTE: time-steps with counter == 0 keep error 0 and still count in the
    # mean below (existing behavior, preserved for metric comparability).
    return torch.mean(error)
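# "评论列表" (Comment list) -- residue scraped from the source web page
# "文章目录" (Table of contents) -- residue scraped from the source web page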