import tensorflow as tf  # TF 1.x API (tf.to_float, tf.summary.image, variable scopes)

# `shape_list` is assumed to be the tensor2tensor `common_layers` helper that
# returns a tensor's dimensions as a list, using static values where known.


def simple_attention(target, source, bias=None):
"""A simple attention function.
Args:
target: a `Tensor` with shape `[batch, target_timesteps, depth]` or
`[batch, target_timesteps_1, target_timesteps_2, depth]`
source: a `Tensor` with shape `[batch, source_timesteps, depth]` or
`[batch, source_timesteps_1, source_timesteps_2, depth]`
bias: an optional `Tensor` with shape `[batch, timesteps, 1, 1]` used
to mask the attention to not attend to padding of input.
Returns:
a `Tensor` with same shape as `target`
"""
  with tf.name_scope("simple_attention", values=[target, source]):
    target_shape = shape_list(target)
    source_shape = shape_list(source)
    # Flatten the two inner time dimensions (the 4-D case) so that both
    # tensors become [batch, timesteps, depth] before computing attention.
    target = tf.reshape(
        target,
        [target_shape[0], target_shape[1] * target_shape[2], target_shape[3]])
    source = tf.reshape(
        source,
        [source_shape[0], source_shape[1] * source_shape[2], source_shape[3]])
    # Scaled dot-product attention: logits are target . source^T / sqrt(depth).
    attention = tf.matmul(target, source, transpose_b=True)
    attention *= tf.rsqrt(tf.to_float(shape_list(target)[2]))
    if bias is not None:
      # bias is [batch, timesteps, 1, 1]; squeeze and re-expand it to
      # [batch, 1, timesteps] so it broadcasts across all target positions.
      attention += tf.expand_dims(tf.squeeze(bias, axis=[2, 3]), axis=1)
    attention = tf.nn.softmax(attention)
    # Only write the attention map as an image summary when variables are not
    # being reused.
    if not tf.get_variable_scope().reuse:
      tf.summary.image(
          "attention", tf.expand_dims(attention, 3), max_outputs=5)
    # Weighted sum over source positions, then restore the original shape.
    attended = tf.matmul(attention, source)
    return tf.reshape(attended, target_shape)
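
As a quick sanity check, here is a minimal usage sketch (not from the original code): it feeds 4-D target/source placeholders of the shape the reshapes above expect and runs the function in a TF 1.x graph. The sizes and placeholder names are made up for illustration, and it still assumes the `shape_list` helper mentioned above is available.

# Minimal usage sketch, assuming TF 1.x and the `shape_list` helper above.
import numpy as np

batch, t1, t2, s1, s2, depth = 2, 3, 4, 5, 6, 8
target_ph = tf.placeholder(tf.float32, [batch, t1, t2, depth])
source_ph = tf.placeholder(tf.float32, [batch, s1, s2, depth])

attended = simple_attention(target_ph, source_ph)

with tf.Session() as sess:
  out = sess.run(attended, feed_dict={
      target_ph: np.random.rand(batch, t1, t2, depth).astype(np.float32),
      source_ph: np.random.rand(batch, s1, s2, depth).astype(np.float32),
  })
  print(out.shape)  # (2, 3, 4, 8) -- same shape as `target`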