def __call__(self, query, previous_alignments):
    """Score the query against the keys and turn the score into alignments.

    The score is the additive form `sum(v * tanh(keys + processed_query))`
    reduced over axis 2. When `self._normalize` is set, `v` is replaced by
    the weight-normalized `g * v / ||v||` and a bias `b` is added inside the
    `tanh`.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      previous_alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      A tuple `(alignments, masked_score)`:
        alignments: Output of `self._probability_fn(score,
          previous_alignments)`; Tensor of dtype matching `self.values` and
          shape `[batch_size, alignments_size]` (`alignments_size` is
          memory's `max_time`).
        masked_score: Output of `self.mask_func(score)`.
          NOTE(review): `mask_func` is defined outside this block; its
          exact semantics should be confirmed against its definition.
    """
    # Local import: the file's import block is not visible from this method,
    # and the bias initializer below needs `init_ops`.
    from tensorflow.python.ops import init_ops

    with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
        processed_query = self.query_layer(query) if self.query_layer else query
        # Reshape from [batch_size, ...] to [batch_size, 1, ...] so the query
        # broadcasts against the keys.
        processed_query = array_ops.expand_dims(processed_query, 1)
        keys = self._keys
        dtype = query.dtype
        v = variable_scope.get_variable(
            "attention_v", [self._num_units], dtype=dtype)
        if self._normalize:
            # Scalar used in weight normalization.
            g = variable_scope.get_variable(
                "attention_g", dtype=dtype,
                initializer=math.sqrt((1. / self._num_units)))
            # Bias added prior to the nonlinearity. It was referenced but
            # never created before, which raised NameError on this branch.
            b = variable_scope.get_variable(
                "attention_b", [self._num_units], dtype=dtype,
                initializer=init_ops.zeros_initializer())
            # normed_v = g * v / ||v||
            normed_v = g * v * math_ops.rsqrt(
                math_ops.reduce_sum(math_ops.square(v)))
            score = math_ops.reduce_sum(
                normed_v * math_ops.tanh(keys + processed_query + b), [2])
        else:
            score = math_ops.reduce_sum(
                v * math_ops.tanh(keys + processed_query), [2])
    alignments = self._probability_fn(score, previous_alignments)
    return alignments, self.mask_func(score)
attention_wrapper.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录