import tensorflow as tf


def dot_product_attention(q,
                          k,
                          v,
                          bias,
                          dropout_rate=0.0,
                          image_shapes=None,
                          name=None,
                          make_image_summary=True):
"""dot-product attention.
Args:
q: a Tensor with shape [batch, heads, length_q, depth_k]
k: a Tensor with shape [batch, heads, length_kv, depth_k]
v: a Tensor with shape [batch, heads, length_kv, depth_v]
bias: bias Tensor (see attention_bias())
dropout_rate: a floating point number
image_shapes: optional tuple of integer scalars.
see comments for attention_image_summary()
name: an optional string
make_image_summary: True if you want an image summary.
Returns:
A Tensor.
"""
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]):
    # [batch, heads, length_q, depth_k] x [batch, heads, length_kv, depth_k]^T
    # -> attention logits of shape [batch, heads, length_q, length_kv].
    logits = tf.matmul(q, k, transpose_b=True)
    if bias is not None:
      logits += bias
    # Normalize over the key/value positions to get attention weights.
    weights = tf.nn.softmax(logits, name="attention_weights")
    # Dropout on the attention weights; keep probability is 1 - dropout_rate.
    weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
    # Weighted sum of the values: [batch, heads, length_q, depth_v].
    return tf.matmul(weights, v)
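The function returns, for each query position, a weighted sum of the value vectors, with weights given by the softmax over query-key dot products. Below is a minimal usage sketch (not part of the original code): the shapes are illustrative, and it assumes TensorFlow 1.x graph mode to match the tf.variable_scope and tf.nn.dropout(weights, keep_prob) calls above.

# Minimal usage sketch (assumption: TF 1.x graph mode, illustrative shapes).
import numpy as np
import tensorflow as tf

batch, heads, length_q, length_kv, depth_k, depth_v = 2, 4, 5, 7, 16, 16
q = tf.constant(np.random.randn(batch, heads, length_q, depth_k), dtype=tf.float32)
k = tf.constant(np.random.randn(batch, heads, length_kv, depth_k), dtype=tf.float32)
v = tf.constant(np.random.randn(batch, heads, length_kv, depth_v), dtype=tf.float32)

out = dot_product_attention(q, k, v, bias=None, dropout_rate=0.1)

with tf.Session() as sess:
  print(sess.run(out).shape)  # (2, 4, 5, 16) = [batch, heads, length_q, depth_v]

Note that the 1/sqrt(depth_k) scaling from "Attention Is All You Need" is not applied inside this function; in the Tensor2Tensor code the query is scaled by the caller (e.g. multihead_attention) before this function is invoked.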