special_fn.py source code

python

Project: tefla  Author: openAGI
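import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope)


def dot_product_attention(q,
                          k,
                          v,
                          bias,
                          dropout_rate=0.0,
                          image_shapes=None,
                          name=None,
                          make_image_summary=True):
    """Dot-product attention.

    Args:
      q: a Tensor with shape [batch, heads, length_q, depth_k]
      k: a Tensor with shape [batch, heads, length_kv, depth_k]
      v: a Tensor with shape [batch, heads, length_kv, depth_v]
      bias: additive bias Tensor broadcastable to
        [batch, heads, length_q, length_kv], or None (see attention_bias())
      dropout_rate: a float in [0, 1); fraction of attention weights dropped
      image_shapes: optional tuple of integer scalars.
        See comments for attention_image_summary(). Not used in this function.
      name: an optional string
      make_image_summary: True if you want an image summary.
        Not used in this function.

    Returns:
      A Tensor with shape [batch, heads, length_q, depth_v].
    """
    with tf.variable_scope(
            name, default_name="dot_product_attention", values=[q, k, v]):
        # Attention scores q . k^T: [batch, heads, length_q, length_kv].
        # No 1/sqrt(depth_k) scaling is applied here; scale q beforehand
        # if scaled dot-product attention is wanted.
        logits = tf.matmul(q, k, transpose_b=True)
        if bias is not None:
            # Additive bias, e.g. large negative values at masked positions.
            logits += bias
        # Normalize over the key axis (last dimension, length_kv).
        weights = tf.nn.softmax(logits, name="attention_weights")
        # TF 1.x dropout takes a keep probability, hence 1.0 - dropout_rate.
        weights = tf.nn.dropout(weights, keep_prob=1.0 - dropout_rate)
        # Weighted sum of values: [batch, heads, length_q, depth_v].
        return tf.matmul(weights, v)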
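To make the computation above concrete, here is a minimal pure-Python sketch of the same math for a single (batch, head) slice: output = softmax(q·kᵀ + bias)·v, with the softmax taken over the key positions. It is an illustration under stated assumptions, not part of tefla: the names `dot_product_attention_ref` and `softmax` are hypothetical, dropout is omitted (equivalent to `dropout_rate=0.0`), and `-1e9` is an assumed large negative masking value standing in for whatever `attention_bias()` produces.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row max for numerical stability; very negative
    # (masked) entries underflow to exactly 0.0 after exp.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention_ref(q: Matrix, k: Matrix, v: Matrix,
                              bias: Optional[Matrix] = None) -> Matrix:
    """Hypothetical reference for one (batch, head) slice.

    q: [length_q][depth_k], k: [length_kv][depth_k], v: [length_kv][depth_v]
    bias: optional [length_q][length_kv] additive bias.
    Returns [length_q][depth_v]. Like the TensorFlow function, it applies
    no 1/sqrt(depth_k) scaling; unlike it, it applies no dropout.
    """
    out = []
    for i, qi in enumerate(q):
        # logits[j] = q_i . k_j, i.e. tf.matmul(q, k, transpose_b=True)
        logits = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if bias is not None:
            logits = [x + b for x, b in zip(logits, bias[i])]
        weights = softmax(logits)
        # Weighted sum of value rows, i.e. tf.matmul(weights, v)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v))
                    for d in range(len(v[0]))])
    return out


if __name__ == "__main__":
    q = [[1.0, 0.0]]
    k = [[1.0, 0.0], [0.0, 1.0]]
    v = [[10.0, 0.0], [0.0, 10.0]]
    print(dot_product_attention_ref(q, k, v))  # about [[7.31, 2.69]]
    # Mask out the second key with a large negative bias.
    print(dot_product_attention_ref(q, k, v, bias=[[0.0, -1e9]]))  # [[10.0, 0.0]]
```

Because `tf.matmul` batches over the leading `[batch, heads]` dimensions, applying this sketch to each slice independently should match the 4-D TensorFlow result when dropout is disabled.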