span_prediction.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:document-qa 作者: allenai 项目源码 文件源码
def predict(self, answer, start_logits, end_logits, mask) -> Prediction:
        """Build the span loss and span/no-answer probabilities for a batch.

        Every (start, end) pair is scored as the sum of its start and end
        logit, one extra learned "none" logit is appended, and a softmax is
        taken over all of those options jointly. The loss is the negative
        log of the probability mass assigned to the options marked correct,
        averaged over the batch, and is added to ``tf.GraphKeys.LOSSES``.

        Args:
            answer: sequence of exactly two ``tf.bool`` tensors. Presumably
                ``answer[0]`` marks correct start positions and ``answer[1]``
                correct end positions, each (batch, l) -- TODO confirm
                against the caller.
            start_logits: per-position start scores; assumed (batch, l).
            end_logits: per-position end scores; assumed (batch, l).
            mask: forwarded unchanged to ``exp_mask`` together with each
                logit tensor.

        Returns:
            A ``ConfidencePrediction`` built from the span probabilities
            (all softmax columns but the last), the masked start and end
            logits, the no-answer probability (last softmax column) and the
            none logit tiled to one value per batch row.

        Raises:
            NotImplementedError: if ``answer`` is not a pair of ``tf.bool``
                tensors.
        """
        # Reject unsupported answer encodings before adding anything to the
        # graph. The exception must be NotImplementedError: `NotImplemented`
        # is a non-callable sentinel, so `raise NotImplemented()` would
        # surface as a confusing TypeError instead.
        if len(answer) != 2 or not all(x.dtype == tf.bool for x in answer):
            raise NotImplementedError(
                "predict only supports `answer` given as a pair of tf.bool tensors")

        masked_start_logits = exp_mask(start_logits, mask)
        masked_end_logits = exp_mask(end_logits, mask)
        batch_dim = tf.shape(start_logits)[0]

        # One learned scalar score for "no answer", repeated for every row.
        none_logit = tf.get_variable("none-logit", initializer=self.non_init, dtype=tf.float32)
        none_logit = tf.tile(tf.expand_dims(none_logit, 0), [batch_dim])

        # Broadcast-add start and end logits into a per-pair score grid, then
        # flatten the grid so each row lists every (start, end) combination.
        all_logits = tf.reshape(tf.expand_dims(masked_start_logits, 1) +
                                tf.expand_dims(masked_end_logits, 2),
                                (batch_dim, -1))

        # (batch, (l * l) + 1) logits including the none option
        all_logits = tf.concat([all_logits, tf.expand_dims(none_logit, 1)], axis=1)
        log_norms = tf.reduce_logsumexp(all_logits, axis=1)

        # Now build a "correctness" mask in the same format: the expand_dims
        # axes mirror the ones used for the logits so the flattened layouts
        # line up element for element.
        correct_mask = tf.logical_and(tf.expand_dims(answer[0], 1), tf.expand_dims(answer[1], 2))
        correct_mask = tf.reshape(correct_mask, (batch_dim, -1))
        # The none option counts as correct exactly when answer[0] has no
        # True entry in that row.
        # NOTE(review): `keep_dims` is the TF 1.x spelling; later releases
        # name this argument `keepdims` -- confirm against the pinned version.
        correct_mask = tf.concat([correct_mask, tf.logical_not(tf.reduce_any(answer[0], axis=1, keep_dims=True))],
                                 axis=1)

        # Push incorrect options towards -inf so that only correct ones
        # contribute to the log-sum-exp.
        log_correct = tf.reduce_logsumexp(
            all_logits + VERY_NEGATIVE_NUMBER * (1 - tf.cast(correct_mask, tf.float32)), axis=1)
        # Mean negative log-probability of choosing any correct option.
        loss = tf.reduce_mean(-(log_correct - log_norms))
        probs = tf.nn.softmax(all_logits)
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)
        return ConfidencePrediction(probs[:, :-1], masked_start_logits, masked_end_logits,
                                    probs[:, -1], none_logit)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号