def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    """Score every (start, end) span jointly and register a span-level loss.

    Each candidate span is scored as ``start_logits[b, i] + end_logits[b, j]``.
    Spans that fall outside the sequence, end before they start, or end more
    than ``self.bound`` tokens after they start are pushed to
    ``VERY_NEGATIVE_NUMBER``. This happens before a softmax cross-entropy over
    all spans. The mean loss is added to the ``tf.GraphKeys.LOSSES`` collection.

    Args:
        answer: Sequence of answer tensors. Only a single tensor is supported.
            Its column 0 is used as the gold start index and column 1 as the
            gold end index (presumably an int (batch, 2) tensor; TODO confirm).
        start_logits: Start-position logits, assumed (batch, seq_len).
        end_logits: End-position logits, assumed (batch, seq_len).
        mask: Per-example sequence lengths (passed to ``tf.sequence_mask`` and
            to ``exp_mask``).

    Returns:
        BoundaryPrediction built from the softmaxed masked start/end logits,
        the masked logits themselves, and the float32 (batch, start, end)
        span mask.

    Raises:
        NotImplementedError: if ``answer`` does not hold exactly one tensor.
    """
    # Only one answer tensor is supported; check before building any graph ops.
    if len(answer) != 1:
        raise NotImplementedError(
            "predict supports exactly one answer tensor, got %d" % len(answer))

    seq_len = tf.shape(start_logits)[1]
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)

    # Explicit score for each span, broadcast to a (batch, start, end) tensor.
    span_scores = tf.expand_dims(start_logits, 2) + tf.expand_dims(end_logits, 1)

    # A span is in-bounds only if both its start and its end lie inside the sequence.
    in_seq = tf.sequence_mask(mask, seq_len)
    span_mask = tf.logical_and(tf.expand_dims(in_seq, 2), tf.expand_dims(in_seq, 1))
    # Keep the band 0 <= end - start <= self.bound. This drops inverted spans
    # and spans longer than the bound (a negative bound keeps the whole upper
    # triangle, per tf.matrix_band_part).
    span_mask = tf.matrix_band_part(span_mask, 0, self.bound)

    # Replace disallowed spans with a very negative score so softmax ignores them.
    span_mask = tf.cast(span_mask, tf.float32)
    span_scores = span_scores * span_mask + (1 - span_mask) * VERY_NEGATIVE_NUMBER

    # Flatten (start, end) row-major, so the gold span's class id is start * seq_len + end.
    gold = answer[0]
    flat_scores = tf.reshape(span_scores, (tf.shape(start_logits)[0], -1))
    flat_labels = gold[:, 0] * seq_len + gold[:, 1]
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=flat_scores, labels=flat_labels)
    loss = tf.reduce_mean(losses)
    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    # NOTE(review): the last argument is the float32 span mask built above, not
    # the caller's ``mask`` lengths. Confirm BoundaryPrediction expects this.
    return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                              tf.nn.softmax(masked_end_logits),
                              masked_start_logits, masked_end_logits, span_mask)
评论列表
文章目录