def pose_loss(input, target):
    """Return the mean Euclidean (L2) distance between ``input`` and ``target``.

    The L2 norm of ``input - target`` is taken along dimension 1, and the
    resulting distances are averaged over all remaining elements into a
    single scalar. This is the mean of *un-squared* norms, not a
    mean-squared error.

    Args:
        input: Predicted tensor with at least two dimensions (dimension 1 is
            reduced). Presumably shaped (batch, pose_dim) — TODO confirm
            against callers.
        target: Tensor broadcastable against ``input``.

    Returns:
        A 0-dim tensor holding the mean distance.
    """
    # torch.norm is deprecated; vector_norm with its default ord=2 computes
    # the same per-row Euclidean norm.
    distances = torch.linalg.vector_norm(input - target, dim=1)
    return torch.mean(distances)