attention_wrapper.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:Question-Answering 作者: MurtyShikhar 项目源码 文件源码
def __call__(self, query, previous_alignments):
    """Score the query against the memory keys (additive attention).

    Computes `sum(v * tanh(keys + processed_query), axis=2)`. When
    `self._normalize` is set, `v` is replaced by the weight-normalized
    `g * v / ||v||` and a bias `b` is added inside the `tanh`.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      previous_alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      A tuple `(alignments, masked_score)`:
        alignments: Result of `self._probability_fn(score,
          previous_alignments)`; Tensor of dtype matching `self.values` and
          shape `[batch_size, alignments_size]` (`alignments_size` is
          memory's `max_time`).
        masked_score: Result of `self.mask_func(score)` applied to the raw
          scores. `mask_func` is defined outside this method -- presumably it
          masks positions beyond each memory's length; confirm against the
          class definition.
    """
    with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
      # Project the query only when a query layer was configured.
      processed_query = self.query_layer(query) if self.query_layer else query
      # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
      processed_query = array_ops.expand_dims(processed_query, 1)
      keys = self._keys
      dtype = query.dtype
      v = variable_scope.get_variable(
          "attention_v", [self._num_units], dtype=dtype)

      if self._normalize:
        # Imported here so the fix does not depend on a module-level import
        # that may be missing from this file.
        from tensorflow.python.ops import init_ops

        # Scalar used in weight normalization
        g = variable_scope.get_variable(
            "attention_g", dtype=dtype,
            initializer=math.sqrt((1. / self._num_units)))
        # Bias added prior to the nonlinearity. The score expression below
        # has always referenced `b`; without this variable the normalized
        # path failed with a NameError.
        b = variable_scope.get_variable(
            "attention_b", [self._num_units], dtype=dtype,
            initializer=init_ops.zeros_initializer())
        # normed_v = g * v / ||v||
        normed_v = g * v * math_ops.rsqrt(
            math_ops.reduce_sum(math_ops.square(v)))
        # Reduce over axis 2 to get one score per position on axis 1.
        score = math_ops.reduce_sum(
            normed_v * math_ops.tanh(keys + processed_query + b), [2])
      else:
        score = math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query),
                                    [2])

    # Turn raw scores into alignments; the masked raw scores are returned too.
    alignments = self._probability_fn(score, previous_alignments)
    return alignments, self.mask_func(score)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号