def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    masked_start_logits = exp_mask(start_logits, mask)
    masked_end_logits = exp_mask(end_logits, mask)
    batch_dim = tf.shape(start_logits)[0]
    if len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
        # Learned scalar logit for the "no answer" option, tiled to one per example
        none_logit = tf.get_variable("none-logit", initializer=self.non_init, dtype=tf.float32)
        none_logit = tf.tile(tf.expand_dims(none_logit, 0), [batch_dim])

        # Score every (start, end) pair: element [b, e, s] = start_logit[b, s] + end_logit[b, e]
        all_logits = tf.reshape(tf.expand_dims(masked_start_logits, 1) +
                                tf.expand_dims(masked_end_logits, 2),
                                (batch_dim, -1))

        # (batch, (l * l) + 1) logits including the none option
        all_logits = tf.concat([all_logits, tf.expand_dims(none_logit, 1)], axis=1)
        log_norms = tf.reduce_logsumexp(all_logits, axis=1)

        # Now build a "correctness" mask in the same format
        correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1), tf.expand_dims(answer[1], 2))
        correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
        # The none option is correct iff the example has no start label anywhere
        correct_mask = tf.concat([correct_mask,
                                  tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))],
                                 axis=1)

        # Negative marginal log-likelihood of all correct spans (or of the none option)
        log_correct = tf.reduce_logsumexp(
            all_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(correct_mask, tf.float32)), axis=1)
        loss = tf.reduce_mean(-(log_correct - log_norms))

        probs = tf.nn.softmax(all_logits)
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return ConfidencePrediction(probs[:, :-1], masked_start_logits, masked_end_logits,
                                    probs[:, -1], none_logit)
    else:
        raise NotImplementedError()
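For reference, `exp_mask` and `VERY_NEGATIVE_NUMBER` come from the surrounding codebase and are not shown above. A minimal sketch of what they are assumed to do (the exact definitions in the original project may differ):

import tensorflow as tf

# Assumed helper: acts as -inf for a float32 softmax
VERY_NEGATIVE_NUMBER = -1e29

def exp_mask(val, mask):
    # Push logits at masked-out positions towards -inf so the softmax
    # assigns them ~0 probability; `mask` is assumed boolean-like per position.
    return val + (1.0 - tf.cast(mask, tf.float32)) * VERY_NEGATIVE_NUMBER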
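The loss is the negative marginal log-likelihood of the correct spans (plus the none option) under one softmax over all (l * l) + 1 candidates. Below is a small NumPy sanity check of that computation with made-up logits and a hypothetical gold span (start = 1, end = 2):

import numpy as np

# Toy setup: batch of 1, context length l = 3
start_logits = np.array([[1.0, 2.0, 0.5]])  # (batch, l)
end_logits = np.array([[0.0, 1.0, 3.0]])    # (batch, l)
none_logit = np.array([0.5])                # (batch,)

# Pairwise span scores, element [b, e, s] = start_logits[b, s] + end_logits[b, e]
span_scores = start_logits[:, None, :] + end_logits[:, :, None]  # (batch, l, l)
all_logits = np.concatenate([span_scores.reshape(1, -1),
                             none_logit[:, None]], axis=1)       # (batch, l*l + 1)

# Gold answer as boolean start/end labels, mirroring answer[0] / answer[1]
starts = np.array([[False, True, False]])
ends = np.array([[False, False, True]])
correct = (starts[:, None, :] & ends[:, :, None]).reshape(1, -1)
correct = np.concatenate([correct, ~starts.any(axis=1, keepdims=True)], axis=1)

def logsumexp(x, axis):
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

log_norms = logsumexp(all_logits, axis=1)
log_correct = logsumexp(np.where(correct, all_logits, -1e29), axis=1)
loss = np.mean(-(log_correct - log_norms))
print(loss)  # cross-entropy of the gold span among the l*l + 1 candidates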