def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
    bound = self.bound
    f1_weight = self.f1_weight
    aggregate = self.aggregate

    # Mask out logits at padded positions before scoring spans
    masked_logits1 = exp_mask(start_logits, mask)
    masked_logits2 = exp_mask(end_logits, mask)

    # Score every span of length <= bound as start_logit + end_logit.
    # Block i holds the (l - i) spans whose end is i tokens after their
    # start, so the concatenated "packed" axis enumerates spans first by
    # length, then by start position (see the sketch after this method).
    span_logits = []
    for i in range(self.bound):
        if i == 0:
            span_logits.append(masked_logits1 + masked_logits2)
        else:
            span_logits.append(masked_logits1[:, :-i] + masked_logits2[:, i:])
    span_logits = tf.concat(span_logits, axis=1)

    l = tf.shape(start_logits)[1]
    if len(answer) == 1:
        answer = answer[0]
        if answer.dtype == tf.int32:
            if f1_weight == 0:
                answer_ix = to_packed_coordinates(answer, l, bound)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=span_logits, labels=answer_ix))
            else:
                # Weight each candidate span by its F1 with the gold span,
                # optionally blended with the exact-match one-hot target
                f1_mask = packed_span_f1_mask(answer, l, bound)
                if f1_weight < 1:
                    f1_mask *= f1_weight
                    # The one-hot depth must match the packed span axis
                    # (not the sequence length l) for the add to broadcast
                    f1_mask += (1 - f1_weight) * tf.one_hot(
                        to_packed_coordinates(answer, l, bound),
                        tf.shape(span_logits)[1])

                # TODO can we stay in log space? (actually it's tricky since
                # f1_mask can have zeros...)
                probs = tf.nn.softmax(span_logits)
                loss = -tf.reduce_mean(tf.log(tf.reduce_sum(probs * f1_mask, axis=1)))
        else:
            # `answer` is a boolean mask over the packed span axis marking
            # every acceptable span
            log_norm = tf.reduce_logsumexp(span_logits, axis=1)
            if aggregate == "sum":
                # Marginal log-likelihood over all correct spans
                log_score = tf.reduce_logsumexp(
                    span_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            elif aggregate == "max":
                # Credit only the highest-scoring correct span
                log_score = tf.reduce_max(
                    span_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(answer, tf.float32)),
                    axis=1)
            else:
                raise NotImplementedError()
            loss = tf.reduce_mean(-(log_score - log_norm))
    else:
        raise NotImplementedError()

    tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
    return PackedSpanPrediction(span_logits, l, bound)
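
# --- Illustrations (not part of the model code) ---
# A minimal NumPy sketch of the packed coordinate layout implied by the
# concat order in predict(): spans are enumerated first by length offset
# i = end - start, then by start position. to_packed_coordinates_np is a
# hypothetical stand-in for the library's to_packed_coordinates helper,
# assuming inclusive (start, end) spans.
import numpy as np

def to_packed_coordinates_np(spans, l):
    # Block i starts at sum_{k<i} (l - k) = i*l - i*(i-1)/2, so the packed
    # index of (start, end) is start + len*l - len*(len-1)/2, len = end - start
    lens = spans[:, 1] - spans[:, 0]
    return spans[:, 0] + lens * l - lens * (lens - 1) // 2

l, bound = 5, 3
# Same enumeration the tf.concat in predict() produces
packed = [(s, s + i) for i in range(bound) for s in range(l - i)]
span = np.array([[1, 2]])                       # start=1, end=2
ix = int(to_packed_coordinates_np(span, l)[0])  # -> 6
assert packed[ix] == (1, 2)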
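
# A toy numeric check of the blended soft-F1 objective: the mask is a convex
# combination of each candidate span's F1 against the gold span and the
# exact-match one-hot target, and the loss is the negative log of the
# expected mask value under the model's span distribution. All numbers
# below are made up for illustration.
import numpy as np

probs = np.array([0.2, 0.5, 0.3])    # softmax(span_logits) for one example
f1 = np.array([0.0, 1.0, 0.5])       # packed_span_f1_mask row vs. the gold span
one_hot = np.array([0.0, 1.0, 0.0])  # gold span in packed coordinates

f1_weight = 0.5
f1_mask = f1_weight * f1 + (1 - f1_weight) * one_hot
loss = -np.log(np.sum(probs * f1_mask))  # mirrors the tf.reduce_sum line
print(loss)  # ~0.553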
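
# The float-mask branch, mirrored in NumPy: a sketch under the assumption
# that `answer` is a 0/1 membership mask over packed coordinates. The loss
# is the negative log-likelihood of either the marginal ("sum") or single
# best ("max") correct span.
import numpy as np

VERY_NEGATIVE_NUMBER = -1e30

def logsumexp(x):
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

span_logits_np = np.array([1.0, 2.0, 0.5, 2.0])
answer_mask = np.array([0.0, 1.0, 0.0, 1.0])  # two spans marked correct

log_norm = logsumexp(span_logits_np)
masked = span_logits_np + VERY_NEGATIVE_NUMBER * (1 - answer_mask)
loss_sum = -(logsumexp(masked) - log_norm)  # aggregate == "sum"
loss_max = -(np.max(masked) - log_norm)     # aggregate == "max"
print(loss_sum, loss_max)  # ~0.259 ~0.952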