span_prediction.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:document-qa 作者: allenai 项目源码 文件源码
def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """Build the span-boundary loss and return the boundary prediction.

        Two encodings of ``answer`` are supported:

        * a single-element sequence whose only item is an integer array
          indexed as ``[:, 0]`` (start labels) and ``[:, 1]`` (end labels);
          a sparse softmax cross-entropy is used for each boundary;
        * a two-element sequence of ``tf.bool`` tensors (start mask, end
          mask) marking every correct boundary; the scores of the marked
          positions are combined according to ``self.aggregate``
          (``"sum"`` or ``"max"``).

        The resulting loss is added to the ``tf.GraphKeys.LOSSES``
        collection as a side effect.

        Args:
            answer: sequence of answer tensors, in one of the two encodings
                described above.
            start_logits: logits for the span start; reduced along axis 1.
            end_logits: logits for the span end; reduced along axis 1.
            mask: passed to ``exp_mask`` together with each logits tensor
                and forwarded to the returned prediction.
                NOTE(review): presumably marks the valid positions of each
                example -- confirm against ``exp_mask``.

        Returns:
            A ``BoundaryPrediction`` built from the softmax of the masked
            start/end logits, the masked logits themselves, and ``mask``.

        Raises:
            ValueError: if the dense encoding is used and ``self.aggregate``
                is neither ``"sum"`` nor ``"max"``.
            NotImplementedError: if ``answer`` matches neither encoding.
        """
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        if len(answer) == 1:
            # The answer span is encoded in a sparse int array.
            answer_spans = answer[0]
            start_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=masked_start_logits, labels=answer_spans[:, 0])
            end_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=masked_end_logits, labels=answer_spans[:, 1])
            loss = tf.add_n(
                [tf.reduce_mean(start_losses), tf.reduce_mean(end_losses)],
                name="loss")
        elif len(answer) == 2 and all(x.dtype == tf.bool for x in answer):
            # All correct start/end bounds are marked in a dense bool array.
            # There may be several answer spans, so an aggregation strategy
            # is needed to combine their scores.
            losses = []
            for answer_mask, logits in zip(
                    answer, [masked_start_logits, masked_end_logits]):
                log_norm = tf.reduce_logsumexp(logits, axis=1)
                # Add a very negative offset to every position that is not
                # marked as correct so it does not contribute to the score.
                answer_logits = logits + VERY_NEGATIVE_NUMBER * (
                    1 - tf.cast(answer_mask, tf.float32))
                if self.aggregate == "sum":
                    log_score = tf.reduce_logsumexp(answer_logits, axis=1)
                elif self.aggregate == "max":
                    log_score = tf.reduce_max(answer_logits, axis=1)
                else:
                    raise ValueError(
                        "Unknown aggregate strategy: %r" % (self.aggregate,))
                # Negative log of (aggregated correct score / normalizer).
                losses.append(tf.reduce_mean(-(log_score - log_norm)))
            loss = tf.add_n(losses)
        else:
            # `NotImplemented` is a non-callable singleton, not an exception;
            # `NotImplementedError` is the exception intended here.
            raise NotImplementedError("Unsupported answer encoding")
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号