def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    """Add the grouped span-boundary loss to the graph and return the prediction.

    Start and end logits are masked with ``exp_mask(..., mask)``. For each of
    the two masked logits tensors the loss is ``tf.reduce_mean`` of
    ``-(log_score - group_norms)``. Both terms are computed with
    ``segment_logsumexp`` over the segments derived from ``group_ids``:
    ``group_norms`` over all masked positions, and ``log_score`` over only the
    positions flagged by the corresponding answer mask. The start and end
    losses are summed with ``tf.add_n`` and added to the
    ``tf.GraphKeys.LOSSES`` collection.

    Args:
        answer: Sequence of exactly three tensors
            ``(start_mask, end_mask, group_ids)``. The masks are cast to
            float32; presumably they hold 0/1 flags marking valid answer
            positions (TODO confirm against the caller).
        start_logits: Logits for the span start.
        end_logits: Logits for the span end.
        mask: Passed to ``exp_mask`` for both logits tensors.

    Returns:
        A ``BoundaryPrediction`` built from the softmax of the masked
        start/end logits, the masked logits themselves, and ``mask``.

    Raises:
        NotImplementedError: If ``answer`` does not have exactly three elements.
        ValueError: If ``self.aggregate`` is not ``"sum"``.
    """
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)

    # Only the grouped (start_mask, end_mask, group_ids) form is supported.
    # Note: the original code did `raise NotImplemented()`, which raises a
    # TypeError because NotImplemented is a constant, not an exception class.
    if len(answer) != 3:
        raise NotImplementedError(
            "predict only supports a 3-element answer "
            "(start_mask, end_mask, group_ids); got %d elements" % len(answer))

    # Checked once up front rather than on every loop iteration.
    if self.aggregate != "sum":
        raise ValueError("Unsupported aggregate: %r" % (self.aggregate,))

    group_ids = answer[2]
    # Turn the ids into segment ids using tf.unique: each element of
    # group_ids is mapped to the int32 index of its value among the unique
    # values (dense ids 0..k-1).
    # NOTE(review): tf.segment_* ops require sorted segment ids, and tf.unique
    # numbers values in order of first occurrence. If segment_logsumexp is built
    # on those ops, this presumably assumes equal group_ids are contiguous.
    # Confirm.
    _, group_segments = tf.unique(group_ids, out_idx=tf.int32)

    losses = []
    # answer[0] pairs with the start logits, answer[1] with the end logits.
    for answer_mask, logits in ((answer[0], masked_start_logits),
                                (answer[1], masked_end_logits)):
        # Normalizer: logsumexp over every position in each group.
        group_norms = segment_logsumexp(logits, group_segments)
        # Push positions outside the answer mask toward -inf, so that only
        # answer positions contribute to the logsumexp.
        not_answer = 1 - tf.cast(answer_mask, tf.float32)
        log_score = segment_logsumexp(logits + VERY_NEGATIVE_NUMBER * not_answer,
                                      group_segments)
        losses.append(tf.reduce_mean(-(log_score - group_norms)))
    loss = tf.add_n(losses)

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                              tf.nn.softmax(masked_end_logits),
                              masked_start_logits, masked_end_logits, mask)
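# Comment list  (scraped web-page navigation text "评论列表"; not code)
# Table of contents  (scraped web-page navigation text "文章目录"; not code)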