span_prediction.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目: document-qa 作者: allenai 项目源码 文件源码
def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """Build the joint span loss and a boundary prediction for one batch.

        Every (start, end) pair is scored as ``start_logits[start] + end_logits[end]``.
        Pairs are masked out when either endpoint is outside the sequence, when
        the span runs backwards (end < start), or when it is longer than
        ``self.bound`` permits. The mean cross-entropy over all remaining spans
        is added to the ``tf.GraphKeys.LOSSES`` collection.

        Args:
            answer: sequence of answer tensors; only the single-tensor case is
                supported. ``answer[0][:, 0]`` is used as the start index and
                ``answer[0][:, 1]`` as the end index (presumably a (batch, 2)
                integer tensor of token positions -- TODO confirm).
            start_logits: span-start scores; dimension 1 is treated as the
                sequence length (assumed (batch, seq_len) -- TODO confirm).
            end_logits: span-end scores, broadcast against ``start_logits``.
            mask: passed to ``exp_mask`` and ``tf.sequence_mask``, so presumably
                per-example sequence lengths of shape (batch,) -- TODO confirm.

        Returns:
            A ``BoundaryPrediction`` holding the softmaxed and raw masked
            start/end logits together with the caller's ``mask``.

        Raises:
            NotImplementedError: if ``answer`` does not contain exactly one tensor.
        """
        # Fail fast, before building any graph ops. (The previous
        # ``raise NotImplemented()`` called a non-callable constant and
        # surfaced as a TypeError rather than NotImplementedError.)
        if len(answer) != 1:
            raise NotImplementedError(
                "joint span loss supports exactly one answer tensor, got %d" % len(answer))
        answer = answer[0]

        batch_size = tf.shape(start_logits)[0]
        seq_len = tf.shape(start_logits)[1]
        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)

        # Explicit score for each span, (batch, start, end) by broadcasting.
        span_scores = tf.expand_dims(start_logits, 2) + tf.expand_dims(end_logits, 1)

        # A span is valid only if both its start and its end are in-bounds...
        in_bounds = tf.sequence_mask(mask, seq_len)
        span_mask = tf.logical_and(tf.expand_dims(in_bounds, 2), tf.expand_dims(in_bounds, 1))
        # ...and it lies in the upper band: 0 <= end - start <= self.bound.
        # (Per tf.matrix_band_part, a negative bound keeps the whole upper triangle.)
        span_mask = tf.matrix_band_part(span_mask, 0, self.bound)

        # Push invalid spans to a very negative score so the softmax ignores them.
        span_mask = tf.cast(span_mask, tf.float32)
        span_scores = span_scores * span_mask + (1 - span_mask) * VERY_NEGATIVE_NUMBER

        # Flatten (start, end) into one class id, start * seq_len + end, which matches
        # the row-major reshape. The loss is then a single softmax over all spans.
        span_scores = tf.reshape(span_scores, (batch_size, -1))
        answer = answer[:, 0] * seq_len + answer[:, 1]
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=span_scores, labels=answer)
        loss = tf.reduce_mean(losses)
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

        # NOTE(review): the caller's ``mask`` is returned here. The previous
        # version re-bound ``mask`` above and so handed BoundaryPrediction the
        # float (batch, start, end) span matrix instead -- confirm no consumer
        # depended on that.
        return BoundaryPrediction(tf.nn.softmax(masked_start_logits),
                                  tf.nn.softmax(masked_end_logits),
                                  masked_start_logits, masked_end_logits, mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号