attention.py source code

Project: jack · Author: uclmr

```python
import tensorflow as tf

from jack.tfutil import misc  # project helper providing mask_for_lengths (import path assumed)


def apply_attention(attn_scores, states, length, is_self=False, with_sentinel=True, reuse=False):
    # attn_scores: [batch, num_queries, num_states] raw scores; states: [batch, num_states, dim];
    # length: [batch] valid sequence lengths.
    # Add large negative values at padded positions so the softmax ignores them.
    attn_scores += tf.expand_dims(misc.mask_for_lengths(length, tf.shape(attn_scores)[2]), 1)
    if is_self:
        # Self-attention: mask the diagonal so a state cannot attend to itself.
        attn_scores += tf.expand_dims(tf.diag(tf.fill([tf.shape(attn_scores)[1]], -1e6)), 0)
    if with_sentinel:
        # Prepend a learned sentinel score before the softmax so the model can attend to
        # "nothing"; slicing it off afterwards leaves probabilities that may sum to < 1.
        with tf.variable_scope('sentinel', reuse=reuse):
            s = tf.get_variable('score', [1, 1, 1], tf.float32, tf.zeros_initializer())
        s = tf.tile(s, [tf.shape(attn_scores)[0], tf.shape(attn_scores)[1], 1])
        attn_probs = tf.nn.softmax(tf.concat([s, attn_scores], 2))
        attn_probs = attn_probs[:, :, 1:]
    else:
        attn_probs = tf.nn.softmax(attn_scores)
    # Weighted sum of states under the attention distribution: [batch, num_queries, dim].
    attn_states = tf.einsum('abd,adc->abc', attn_probs, states)
    return attn_scores, attn_probs, attn_states
```
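
For context, here is a minimal usage sketch in TF1 graph mode. The dot-product scorer, placeholder shapes, and session setup are illustrative assumptions, not part of the original file.

```python
import numpy as np
import tensorflow as tf

num_states, dim = 5, 8
states = tf.placeholder(tf.float32, [None, num_states, dim])
length = tf.placeholder(tf.int32, [None])

# Dot-product self-attention scores: [batch, num_states, num_states] (assumed scorer).
raw_scores = tf.matmul(states, states, transpose_b=True)
scores, probs, attended = apply_attention(raw_scores, states, length,
                                          is_self=True, with_sentinel=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    p, a = sess.run([probs, attended], feed_dict={
        states: np.random.randn(2, num_states, dim).astype(np.float32),
        length: [5, 3],  # second example has two padded positions
    })
    print(p.shape, a.shape)  # (2, 5, 5) (2, 5, 8)
```

Note that with `with_sentinel=True` each row of `probs` can sum to less than 1: the mass assigned to the learned sentinel is discarded, which lets the model effectively attend to nothing.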