import tensorflow as tf

# Note: ``linear`` is an affine projection helper (y = xW [+ b]) assumed to
# be defined elsewhere in the same module; a minimal sketch is given after
# the function.


def additive_attention(queries, keys, values, bias, hidden_size, concat=False,
                       keep_prob=None, dtype=None, scope=None):
""" Additive attention mechanism. This layer is implemented using a
one layer feed forward neural network
:param queries: A tensor with shape [batch, heads, length_q, depth_k]
:param keys: A tensor with shape [batch, heads, length_kv, depth_k]
:param values: A tensor with shape [batch, heads, length_kv, depth_v]
:param bias: A tensor
:param hidden_size: An integer
:param concat: A boolean value. If ``concat'' is set to True, then
the computation of attention mechanism is following $tanh(W[q, k])$.
When ``concat'' is set to False, the computation is following
$tanh(Wq + Vk)$
:param keep_prob: a scalar in [0, 1]
:param dtype: An optional instance of tf.DType
:param scope: An optional string, the scope of this layer
:returns: A dict with the following keys:
weights: A tensor with shape [batch, length_q]
outputs: A tensor with shape [batch, length_q, depth_v]
"""
    with tf.variable_scope(scope, default_name="additive_attention",
                           values=[queries, keys, values, bias], dtype=dtype):
        length_q = tf.shape(queries)[2]
        length_kv = tf.shape(keys)[2]

        q = tf.tile(tf.expand_dims(queries, 3), [1, 1, 1, length_kv, 1])
        k = tf.tile(tf.expand_dims(keys, 2), [1, 1, length_q, 1, 1])
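        # After tiling, both q and k have shape
        # [batch, heads, length_q, length_kv, depth_k], pairing every
        # query position with every key position.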
        if concat:
            combined = tf.tanh(linear(tf.concat([q, k], axis=-1), hidden_size,
                                      True, True, name="qk_transform"))
        else:
            # Transform the tiled tensors (not the raw queries/keys), so
            # that the two operands of the addition have matching shapes.
            q = linear(q, hidden_size, True, True, name="q_transform")
            k = linear(k, hidden_size, True, True, name="key_transform")

            combined = tf.tanh(q + k)
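        # Both branches yield ``combined`` with shape
        # [batch, heads, length_q, length_kv, hidden_size]; the width-1
        # ``linear`` below plays the role of v^T in the classic additive
        # score v^T tanh(Wq + Vk).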
        # shape: [batch, heads, length_q, length_kv]
        logits = tf.squeeze(linear(combined, 1, True, True, name="logits"),
                            axis=-1)
        if bias is not None:
            logits += bias

        # Softmax over the last axis (length_kv): each query position gets
        # a distribution over the key positions.
        weights = tf.nn.softmax(logits, name="attention_weights")
        # Apply dropout to the attention weights only when a keep
        # probability below 1.0 is actually given.
        if keep_prob is not None and keep_prob < 1.0:
            weights = tf.nn.dropout(weights, keep_prob)
        outputs = tf.matmul(weights, values)

    return {"weights": weights, "outputs": outputs}