span_prediction.py source code

python

Project: document-qa    Author: allenai
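# Assumed imports (TensorFlow 1.x, where tf.contrib is still available):
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected

# get_keras_initialization is a helper defined elsewhere in the document-qa project.

def apply(self, is_train, context_embed, answer, context_mask=None):
    # Build the weight initializer from the configured self.init setting.
    init_fn = get_keras_initialization(self.init)

    # Map the context embeddings to two sequences: m1 for start bounds, m2 for end bounds.
    with tf.variable_scope("bounds_encoding"):
        m1, m2 = self.predictor.apply(is_train, context_embed, context_mask)

    # One score per token: [batch, seq_len, dim] -> [batch, seq_len, 1] -> [batch, seq_len].
    with tf.variable_scope("start_pred"):
        logits1 = fully_connected(m1, 1, activation_fn=None,
                                  weights_initializer=init_fn)
        logits1 = tf.squeeze(logits1, axis=[2])

    # Same projection for the end scores; a separate scope gives it its own weights.
    with tf.variable_scope("end_pred"):
        logits2 = fully_connected(m2, 1, activation_fn=None,
                                  weights_initializer=init_fn)
        logits2 = tf.squeeze(logits2, axis=[2])

    # Hand the start/end logits to the span predictor, which builds the model output.
    with tf.variable_scope("predict_span"):
        return self.span_predictor.predict(answer, logits1, logits2, context_mask)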
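The method scores every context token twice (once as a possible answer start, once as a possible end) with a one-unit dense layer, then leaves decoding to `self.span_predictor`, whose implementation is not shown here. The pure-Python sketch below only illustrates that flow for a single unbatched example, under stated assumptions: `project` is a hypothetical stand-in for `fully_connected(..., 1)` followed by the squeeze, and `best_span` applies one common decoding rule (maximize start score + end score subject to start <= end, an optional valid `length`, and an optional `max_len`). This rule is assumed for illustration, not taken from document-qa.

```python
from typing import List, Optional, Sequence, Tuple


def project(encodings: Sequence[Sequence[float]],
            weights: Sequence[float], bias: float = 0.0) -> List[float]:
    """Hypothetical stand-in for a dense layer with one output unit plus squeeze:
    each token vector is reduced to a single score (dot product + bias)."""
    return [sum(x * w for x, w in zip(vec, weights)) + bias for vec in encodings]


def best_span(start_logits: Sequence[float], end_logits: Sequence[float],
              length: Optional[int] = None,
              max_len: Optional[int] = None) -> Tuple[int, int, float]:
    """Return (start, end, score) maximizing start_logits[i] + end_logits[j]
    with i <= j. Assumption: only the first `length` positions are real tokens
    (the rest is padding); `max_len` caps the span at that many tokens."""
    if len(start_logits) != len(end_logits):
        raise ValueError("start and end logits must have the same length")
    n = len(start_logits) if length is None else length
    if n < 1 or n > len(start_logits):
        raise ValueError("length must be between 1 and the number of logits")
    if max_len is not None and max_len < 1:
        raise ValueError("max_len must be at least 1")

    best = (0, 0, start_logits[0] + end_logits[0])
    for i in range(n):
        stop = n if max_len is None else min(n, i + max_len)
        for j in range(i, stop):
            score = start_logits[i] + end_logits[j]
            if score > best[2]:  # strict '>' keeps the earliest span on ties
                best = (i, j, score)
    return best


# Usage: two toy encodings (4 tokens, 2 features) and a shared weight vector.
w = [1.0, 2.0]
start = project([[1.0, 0.0], [1.0, 1.0], [0.5, 0.0], [0.5, 1.0]], w)  # [1.0, 3.0, 0.5, 2.5]
end = project([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [1.0, 2.0]], w)    # [0.0, 1.0, 4.0, 5.0]
print(best_span(start, end))             # (1, 3, 8.0)
print(best_span(start, end, length=3))   # (1, 2, 7.0): token 3 treated as padding
print(best_span(start, end, max_len=2))  # (3, 3, 7.5): long spans excluded
```

The double loop costs O(length × max_len), which is acceptable for a sketch. A batched TensorFlow version would instead work on the [batch, seq_len] logit tensors produced above.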