layers.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:deeppavlov 作者: deepmipt 项目源码 文件源码
def answer_end_pred(context_encoding, question_attention_vector, context_mask, answer_start_distribution, W, dropout_rate):
    """Build the answer-end prediction layer, conditioned on the predicted start.

    The most probable start position is taken from
    ``answer_start_distribution`` (argmax over axis 1). The context encoding
    at that position is gathered for every sample and expanded with
    ``repeat_vector``. The resulting "start feature" is concatenated with the
    context encoding, the question attention vector and two element-wise
    products, scored by a two-layer time-distributed dense network, and
    normalised with ``masked_softmax``.

    Args:
        context_encoding: Context encoding tensor, indexed here as
            ``[sample, position]``. Presumably (batch, context_len, dim) --
            TODO confirm against the caller.
        question_attention_vector: Tensor that is concatenated and multiplied
            element-wise with ``context_encoding``; presumably of a matching
            shape -- TODO confirm.
        context_mask: Mask forwarded unchanged to ``masked_softmax``.
        answer_start_distribution: Start-position scores; the argmax is taken
            over axis 1. Presumably (batch, context_len) -- TODO confirm.
        W: Number of units in the hidden ``Dense`` layer.
        dropout_rate: Rate passed to ``Dropout`` after the hidden layer.

    Returns:
        The result of ``flatten`` applied to the masked softmax of the
        per-position end scores.
    """

    # Answer end prediction depends on the start prediction: take the index
    # of the most probable start position. K.argmax is cast to int32 so it
    # can be stacked with the int32 batch indices produced by K.tf.range.
    start_index = Lambda(
        lambda dist: K.tf.cast(K.argmax(dist, axis=1), dtype=K.tf.int32)
    )(answer_start_distribution)

    def gather_start_encoding(arg):
        # Pair every sample index with its start index and pick the matching
        # context encoding. All TensorFlow access goes through ``K.tf`` so the
        # block does not rely on a separate module-level ``tf`` name.
        encoding, index = arg
        sample_index = K.tf.range(K.tf.shape(index)[0])
        return K.tf.gather_nd(encoding, K.tf.stack([sample_index, index], axis=1))

    start_feature = Lambda(gather_start_encoding)([context_encoding, start_index])

    # Expand the gathered start encoding via repeat_vector (second argument is
    # the context encoding) so it can be combined with the context encoding.
    start_feature = Lambda(lambda q: repeat_vector(q[0], q[1]))([start_feature, context_encoding])

    # Answer end prediction: raw features plus the interactions of the context
    # with the question attention vector and with the start feature.
    answer_end = Lambda(lambda arg: concatenate([
        arg[0],
        arg[1],
        arg[2],
        multiply([arg[0], arg[1]]),
        multiply([arg[0], arg[2]]),
    ]))([context_encoding, question_attention_vector, start_feature])

    answer_end = TimeDistributed(Dense(W, activation='relu'))(answer_end)
    answer_end = Dropout(rate=dropout_rate)(answer_end)
    answer_end = TimeDistributed(Dense(1))(answer_end)

    # Apply masking, then flatten the result.
    answer_end = Lambda(lambda q: masked_softmax(q[0], q[1]))([answer_end, context_mask])
    answer_end = Lambda(lambda q: flatten(q))(answer_end)
    return answer_end
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号