def answer_end_pred(
    context_encoding,
    question_attention_vector,
    context_mask,
    answer_start_distribution,
    W,
    dropout_rate,
):
    """Build the answer-end prediction head, conditioned on the predicted start.

    The most probable start position is taken from
    ``answer_start_distribution`` (argmax over axis 1), the context encoding
    at that position is gathered for every example, and the result is
    combined with the context encoding and the question attention vector
    before being scored by a small feed-forward network.

    Args:
        context_encoding: Context tensor indexed as ``[batch, position, ...]``;
            one entry per example is gathered at the predicted start position.
        question_attention_vector: Question summary tensor. It is concatenated
            and multiplied element-wise with ``context_encoding``, so it is
            presumably already expanded to a compatible shape -- TODO confirm
            against the caller.
        context_mask: Mask forwarded unchanged to ``masked_softmax``.
        answer_start_distribution: Rank-2 tensor of start scores; axis 1 is
            the context-position axis.
        W: Number of units in the hidden ``Dense`` layer.
        dropout_rate: Rate passed to the ``Dropout`` layer.

    Returns:
        The tensor produced by ``flatten`` applied to the masked softmax over
        the per-position end scores.
    """
    # Most probable start position per example. Cast to int32 here so it can
    # be stacked with the int32 batch indices built below.
    start_index = Lambda(
        lambda dist: K.tf.cast(K.argmax(dist, axis=1), dtype=K.tf.int32)
    )(answer_start_distribution)

    # Build (batch_index, start_index) pairs and pick context_encoding at the
    # predicted start of each example. All TensorFlow ops go through `K.tf`
    # so the block does not rely on a separate module-level `tf` name; the
    # index tensor is already int32, so no second cast is needed.
    start_feature = Lambda(
        lambda arg: K.tf.gather_nd(
            arg[0],
            K.tf.stack(
                [K.tf.range(K.tf.shape(arg[1])[0]), arg[1]],
                axis=1,
            ),
        )
    )([context_encoding, start_index])

    # NOTE(review): `repeat_vector` is a project helper; presumably it tiles
    # the start feature to line up with `context_encoding` -- confirm.
    start_feature = Lambda(
        lambda q: repeat_vector(q[0], q[1])
    )([start_feature, context_encoding])

    # Answer end prediction: context, question and start features plus the
    # element-wise interactions of the context with each of the other two.
    answer_end = Lambda(
        lambda arg: concatenate([
            arg[0],
            arg[1],
            arg[2],
            multiply([arg[0], arg[1]]),
            multiply([arg[0], arg[2]]),
        ])
    )([context_encoding, question_attention_vector, start_feature])

    answer_end = TimeDistributed(Dense(W, activation='relu'))(answer_end)
    answer_end = Dropout(rate=dropout_rate)(answer_end)
    # One score per position.
    answer_end = TimeDistributed(Dense(1))(answer_end)

    # Apply masking while normalising the scores.
    answer_end = Lambda(
        lambda q: masked_softmax(q[0], q[1])
    )([answer_end, context_mask])
    answer_end = Lambda(lambda q: flatten(q))(answer_end)
    return answer_end
评论列表
文章目录