def get_predicted_antecedents(self, antecedents, antecedent_scores):
    """Select the best-scoring antecedent for each mention.

    The arg-max column of each row of ``antecedent_scores`` is shifted down
    by one. If column 0 wins, the mention gets -1. Presumably column 0 is the
    dummy "no antecedent" score; confirm against the scorer. If column
    ``k >= 1`` wins, the mention gets ``antecedents[row, k - 1]``.

    Args:
        antecedents: candidate antecedent ids, indexed as
            ``antecedents[mention, slot]``. The original docstring says it
            comes from a C++ function; its exact shape and dtype are
            assumed to be ``[num_mentions, max_ant]``, which is a TODO to
            confirm.
        antecedent_scores: ``[num_mentions, max_ant + 1]`` scores output by
            the fully-connected network that computes antecedent scores.

    Returns:
        A list with one entry per mention. Each entry is -1 when column 0
        wins; otherwise it is the chosen element of ``antecedents``.
    """
    # Shift by one so that column 0 becomes -1 and real slots start at 0.
    best_slots = np.argmax(antecedent_scores, axis=1) - 1
    return [
        -1 if slot < 0 else antecedents[mention, slot]
        for mention, slot in enumerate(best_slots)
    ]
评论列表
文章目录