import tensorflow as tf  # TensorFlow 1.x API (variable_scope, get_variable, tf.diag)


def apply_attention(attn_scores, states, length, is_self=False, with_sentinel=True, reuse=False):
    # attn_scores: [batch, num_queries, num_states] raw attention scores
    # states:      [batch, num_states, dim] states to attend over
    # length:      [batch] valid lengths of `states`
    # Mask padded positions with a large negative value before the softmax;
    # misc.mask_for_lengths is a project helper that produces that mask.
    attn_scores += tf.expand_dims(misc.mask_for_lengths(length, tf.shape(attn_scores)[2]), 1)
    if is_self:
        # exclude attending to the state itself by masking the diagonal
        attn_scores += tf.expand_dims(tf.diag(tf.fill([tf.shape(attn_scores)[1]], -1e6)), 0)
    if with_sentinel:
        # learned sentinel score lets the model attend to "nothing"
        with tf.variable_scope('sentinel', reuse=reuse):
            s = tf.get_variable('score', [1, 1, 1], tf.float32, tf.zeros_initializer())
        s = tf.tile(s, [tf.shape(attn_scores)[0], tf.shape(attn_scores)[1], 1])
        # softmax over [sentinel; scores], then drop the sentinel column
        attn_probs = tf.nn.softmax(tf.concat([s, attn_scores], 2))
        attn_probs = attn_probs[:, :, 1:]
    else:
        attn_probs = tf.nn.softmax(attn_scores)
    # weighted sum of states: [batch, num_queries, dim]
    attn_states = tf.einsum('abd,adc->abc', attn_probs, states)
    return attn_scores, attn_probs, attn_states
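
A minimal usage sketch, assuming TensorFlow 1.x. The `misc` stub below only illustrates what mask_for_lengths is expected to return (0 for valid positions, a large negative value for padded ones); the project's real helper, the dot-product scoring, and the shapes are illustrative assumptions, not part of the original code.

# Hypothetical usage example (assumes this sits in the same module as apply_attention).
import types
import tensorflow as tf

def _mask_for_lengths(lengths, max_length, value=-1e6):
    # 0 for valid positions, large negative value for padded ones (assumed behavior)
    valid = tf.sequence_mask(lengths, max_length, dtype=tf.float32)
    return (1.0 - valid) * value

misc = types.SimpleNamespace(mask_for_lengths=_mask_for_lengths)  # stand-in for the project helper

batch, num_queries, num_states, dim = 2, 4, 5, 8
queries = tf.random_normal([batch, num_queries, dim])
states = tf.random_normal([batch, num_states, dim])
length = tf.constant([5, 3])  # valid lengths of `states` per example

# Dot-product scores: [batch, num_queries, num_states]
attn_scores = tf.einsum('aqd,asd->aqs', queries, states)

_, attn_probs, attn_states = apply_attention(attn_scores, states, length,
                                             is_self=False, with_sentinel=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initializes the sentinel score
    print(sess.run(attn_states).shape)  # (2, 4, 8)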