def get_final_error(ret_nodes, nodes, assumedNodesPresent, trueNodesPresent):
    '''
    Computes the final displacement error (error at the last predicted time-step).

    Parameters
    ==========
    ret_nodes : A tensor of shape pred_length x numNodes x 2
        Contains the predicted positions for the nodes
    nodes : A tensor of shape pred_length x numNodes x 2
        Contains the true positions for the nodes
    assumedNodesPresent : An iterable of nodeIDs
        The nodes for which predictions are evaluated
    trueNodesPresent : A list of lists, of size pred_length
        Each list contains the nodeIDs of the nodes actually present at that
        time-step; only its last entry is consulted

    Returns
    =======
    Error : Mean euclidean distance, at the last time-step, between the
        predicted and true positions, averaged over the nodes that are in
        assumedNodesPresent AND in trueNodesPresent[pred_length - 1].
        If no node satisfies both, the Python int 0 is returned; otherwise the
        result is a tensor produced by torch.norm arithmetic.
    '''
    pred_length = ret_nodes.size()[0]
    # Only the last time-step matters for the final displacement error.
    tstep = pred_length - 1
    present_at_end = trueNodesPresent[tstep]

    error = 0
    counter = 0
    for nodeID in assumedNodesPresent:
        # Skip nodes we predicted for but which are absent in the ground truth
        # at the final step; there is no true position to compare against.
        if nodeID not in present_at_end:
            continue
        pred_pos = ret_nodes[tstep, nodeID, :]
        true_pos = nodes[tstep, nodeID, :]
        # `error` starts as int 0, so the first `+=` yields a fresh tensor
        # (0 + norm) rather than mutating torch.norm's output in place.
        error += torch.norm(pred_pos - true_pos, p=2)
        counter += 1

    # Avoid division by zero when no node could be evaluated.
    if counter != 0:
        error = error / counter
    return error
评论列表
文章目录