import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope)


def interaction_layer(seq1, seq1_length, seq2, seq2_length, seq1_to_seq2,
                      module='attention_matching', attn_type='bilinear_diagonal',
                      scaled=True, with_sentinel=False, name='interaction_layer',
                      reuse=False, num_layers=1, encoder=None, concat=True, **kwargs):
    """Lets seq1 interact with seq2 via the selected interaction module.

    `attention_matching_layer`, `bidaf_layer` and `coattention_layer` are
    helper functions defined elsewhere in the same module.
    """
    with tf.variable_scope(name, reuse=reuse):
        if seq1_to_seq2 is not None:
            # Align seq2 with seq1: gather, for each seq1 element,
            # the seq2 row (and its length) it should attend over.
            seq2 = tf.gather(seq2, seq1_to_seq2)
            seq2_length = tf.gather(seq2_length, seq1_to_seq2)
        if module == 'attention_matching':
            out = attention_matching_layer(seq1, seq1_length, seq2, seq2_length,
                                           attn_type, scaled, with_sentinel)
        elif module == 'bidaf':
            out = bidaf_layer(seq1, seq1_length, seq2, seq2_length)
        elif module == 'coattention':
            out = coattention_layer(
                seq1, seq1_length, seq2, seq2_length, attn_type, scaled,
                with_sentinel, num_layers, encoder)
        else:
            raise ValueError("Unknown interaction type: %s" % module)
        if concat:
            # Concatenate the interaction output with seq1 along the feature axis.
            out = tf.concat([seq1, out], 2)
        return out
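
For context, here is a minimal usage sketch. The shapes, placeholder names, and the feature dimension are illustrative assumptions, not part of the original code:

# Usage sketch under TF 1.x; all shapes and names below are hypothetical.
dim = 128
seq1 = tf.placeholder(tf.float32, [None, None, dim])   # [batch, len1, dim]
seq1_length = tf.placeholder(tf.int32, [None])         # [batch]
seq2 = tf.placeholder(tf.float32, [None, None, dim])   # [n_seq2, len2, dim]
seq2_length = tf.placeholder(tf.int32, [None])         # [n_seq2]

# Pass seq1_to_seq2=None when seq1 and seq2 are already aligned one-to-one;
# otherwise supply an int32 vector mapping each seq1 row to a seq2 row.
out = interaction_layer(seq1, seq1_length, seq2, seq2_length,
                        seq1_to_seq2=None, module='bidaf')
# With concat=True, `out` is seq1 concatenated with the interaction output
# along axis 2, so its last dimension is dim + the module's output size.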