def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    """Build the span-extraction loss and return boundary probability distributions.

    Args:
        answer: either a 1-tuple holding an int tensor of (start, end) index
            pairs (indexed as ``[:, 0]``/``[:, 1]``), or a 2-tuple of
            ``tf.bool`` tensors marking all correct start/end positions
            (multiple answer spans allowed) -- assumed shapes per the indexing
            below; confirm against callers.
        start_logits: per-position start scores; masked via ``exp_mask``.
        end_logits: per-position end scores; masked via ``exp_mask``.
        mask: validity mask passed to ``exp_mask`` before softmax/loss.

    Returns:
        BoundaryPrediction with masked start/end softmax probabilities, the
        masked logits, and the mask.

    Raises:
        ValueError: if ``self.aggregate`` is neither "sum" nor "max".
        NotImplementedError: if ``answer`` uses an unsupported encoding.

    Side effects:
        Adds the computed loss to the ``tf.GraphKeys.LOSSES`` collection.
    """
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)
    if len(answer) == 1:
        # answer span is encoded in a sparse int array
        answer_spans = answer[0]
        losses1 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=masked_start_logits, labels=answer_spans[:, 0])
        losses2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=masked_end_logits, labels=answer_spans[:, 1])
        loss = tf.add_n([tf.reduce_mean(losses1), tf.reduce_mean(losses2)], name="loss")
    elif len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
        # All correct start/end bounds are marked in a dense bool array.
        # There might be multiple answer spans, so we need an aggregation strategy.
        losses = []
        for answer_mask, logits in zip(answer, [masked_start_logits, masked_end_logits]):
            # log of the full partition function over all positions
            log_norm = tf.reduce_logsumexp(logits, axis=1)
            # Push non-answer positions toward -inf so they contribute
            # (almost) nothing to the aggregated answer score.
            non_answer_penalty = VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer_mask, tf.float32))
            if self.aggregate == "sum":
                # Marginalize: log of the summed probability of all answer positions.
                log_score = tf.reduce_logsumexp(logits + non_answer_penalty, axis=1)
            elif self.aggregate == "max":
                # Score only the single highest-scoring answer position.
                log_score = tf.reduce_max(logits + non_answer_penalty, axis=1)
            else:
                # BUG FIX: was a bare ValueError() with no message, making the
                # failure undiagnosable.
                raise ValueError("Unknown aggregate mode: %r" % (self.aggregate,))
            # negative log-likelihood of the (aggregated) answer positions
            losses.append(tf.reduce_mean(-(log_score - log_norm)))
        loss = tf.add_n(losses)
    else:
        # BUG FIX: original `raise NotImplemented()` raised a TypeError
        # (NotImplemented is a sentinel constant, not callable); raise the
        # intended exception class instead.
        raise NotImplementedError("Unsupported answer encoding")
    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                              tf.nn.softmax(masked_end_logits),
                              masked_start_logits, masked_end_logits, mask)
# NOTE(review): removed web-page navigation residue ("评论列表" = comment list,
# "文章目录" = article table of contents) accidentally pasted from a blog page.