interaction_layer.py source code

Language: Python

Project: jack  Author: uclmr
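```python
import tensorflow as tf  # TF 1.x API (tf.variable_scope); under TF 2.x use tf.compat.v1

# attention_matching_layer, bidaf_layer and coattention_layer are defined
# elsewhere in the same jack module and are not shown here.


def interaction_layer(seq1, seq1_length, seq2, seq2_length, seq1_to_seq2,
                      module='attention_matching', attn_type='bilinear_diagonal', scaled=True, with_sentinel=False,
                      name='interaction_layer', reuse=False, num_layers=1, encoder=None, concat=True, **kwargs):
    """Computes interaction features between seq1 and seq2 with the selected module.

    When concat is True, the result is joined to seq1 along the feature axis (axis 2).
    Extra keyword arguments are accepted but ignored.
    """
    with tf.variable_scope(name, reuse=reuse):
        # seq1_to_seq2[i] is the seq2 row paired with seq1 row i; gathering copies
        # seq2 (and its lengths) so both batches line up row by row.
        if seq1_to_seq2 is not None:
            seq2 = tf.gather(seq2, seq1_to_seq2)
            seq2_length = tf.gather(seq2_length, seq1_to_seq2)
        # Dispatch to the requested interaction module.
        if module == 'attention_matching':
            out = attention_matching_layer(seq1, seq1_length, seq2, seq2_length,
                                           attn_type, scaled, with_sentinel)
        elif module == 'bidaf':
            out = bidaf_layer(seq1, seq1_length, seq2, seq2_length)
        elif module == 'coattention':
            out = coattention_layer(
                seq1, seq1_length, seq2, seq2_length, attn_type, scaled, with_sentinel, num_layers, encoder)
        else:
            raise ValueError("Unknown interaction type: %s" % module)

    # Optionally keep seq1's own features next to the interaction output.
    if concat:
        out = tf.concat([seq1, out], 2)
    return out
```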
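The control flow of `interaction_layer` (optional realignment of seq2 with `tf.gather`, dispatch on `module`, optional concatenation along the feature axis) can be illustrated without TensorFlow. The sketch below is a hypothetical pure-Python stand-in, not jack's code: nested lists play the role of `[batch, time, dim]` tensors, and `uniform_attention` is an invented toy module (a length-masked mean over seq2) standing in for the real attention layers. The names `toy_interaction_layer`, `gather`, `concat_features` and `MODULES` are likewise illustrative assumptions.

```python
from typing import Callable, Dict, List, Optional, Sequence

# A batch is [batch, time, dim] as nested lists (stand-in for a TF tensor).
Batch = List[List[List[float]]]


def gather(items: Sequence, indices: Sequence[int]) -> list:
    """Pure-Python stand-in for tf.gather on axis 0."""
    return [items[i] for i in indices]


def concat_features(a: Batch, b: Batch) -> Batch:
    """Pure-Python stand-in for tf.concat([a, b], 2): joins features per position."""
    return [[va + vb for va, vb in zip(ra, rb)] for ra, rb in zip(a, b)]


def uniform_attention(seq1: Batch, seq2: Batch, seq2_length: List[int]) -> Batch:
    """Hypothetical toy module: every seq1 position attends equally to the first
    seq2_length[b] positions of seq2 (a masked mean; assumes length >= 1)."""
    out = []
    for rows1, rows2, n in zip(seq1, seq2, seq2_length):
        valid = rows2[:n]  # drop padding positions beyond the true length
        mean = [sum(col) / n for col in zip(*valid)]
        out.append([list(mean) for _ in rows1])  # same time length as seq1
    return out


# Registry used instead of an if/elif chain (an illustrative design choice).
MODULES: Dict[str, Callable[[Batch, Batch, List[int]], Batch]] = {
    'uniform': uniform_attention,
}


def toy_interaction_layer(seq1: Batch, seq2: Batch, seq2_length: List[int],
                          seq1_to_seq2: Optional[List[int]] = None,
                          module: str = 'uniform', concat: bool = True) -> Batch:
    # Realign seq2 so that row i pairs with seq1 row i.
    if seq1_to_seq2 is not None:
        seq2 = gather(seq2, seq1_to_seq2)
        seq2_length = gather(seq2_length, seq1_to_seq2)
    if module not in MODULES:
        raise ValueError("Unknown interaction type: %s" % module)
    out = MODULES[module](seq1, seq2, seq2_length)
    # Optionally keep seq1's own features next to the interaction output.
    return concat_features(seq1, out) if concat else out


if __name__ == "__main__":
    seq1 = [[[1.0], [2.0]], [[3.0], [4.0]]]           # 2 rows, 2 positions, dim 1
    seq2 = [[[10.0, 0.0], [20.0, 0.0], [99.0, 99.0]]]  # 1 row, last position is padding
    print(toy_interaction_layer(seq1, seq2, [2], seq1_to_seq2=[0, 0]))
    # → [[[1.0, 15.0, 0.0], [2.0, 15.0, 0.0]], [[3.0, 15.0, 0.0], [4.0, 15.0, 0.0]]]
```

In the usage example both seq1 rows map to the single seq2 row (`seq1_to_seq2=[0, 0]`), the padded third position is ignored because its length is 2, and each seq1 feature vector is extended with the masked mean `[15.0, 0.0]`. As in the original function, an unknown `module` name raises `ValueError`.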