def __call__(self, query, previous_alignments):
    '''Score the query based on the keys and values.

    Computes additive (Bahdanau-style) attention energies and passes them,
    together with `previous_alignments`, to `self._probability_fn`:

        score = reduce_sum(V * tanh(E), axis=2)

    where E = keys + query (+ location features when
    `FLAGS.use_conv_feat_att` is set) (+ `self.b` when `self._normalize`),
    and V is `self.v` or, when normalizing, `self.g * self.v / ||self.v||`.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`. Projected with `self.query_layer` when
        that layer is set.
      previous_alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`). Forwarded to
        `self._probability_fn`. When `FLAGS.use_conv_feat_att` is set, it is
        also convolved with `self.conv_filt` to produce location features.

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
    '''
    with tf.variable_scope(None, 'bahdanau_attention', [query]):
        processed_query = (
            self.query_layer(query) if self.query_layer else query)
        # Reshape from [batch_size, ...] to [batch_size, 1, ...] so the query
        # broadcasts across every memory time step of the keys.
        processed_query = tf.expand_dims(processed_query, 1)

        # Argument of the tanh, shared by both scoring branches below.
        energy_input = self._keys + processed_query
        if FLAGS.use_conv_feat_att:
            # Location features: treat previous_alignments as a 1-channel
            # sequence, [batch, max_time] -> [batch, max_time, 1], and
            # convolve with stride 1 and 'SAME' padding so the time
            # dimension is preserved.
            # NOTE(review): assumes the output-channel count of
            # self.conv_filt matches the keys' last dimension so the sum
            # broadcasts -- confirm against where conv_filt is created.
            conv_feat = tf.nn.conv1d(
                tf.expand_dims(previous_alignments, 2),
                self.conv_filt, 1, 'SAME')
            # Added before the normalize/unnormalized split so that enabling
            # normalization does not silently drop the location features.
            energy_input = energy_input + conv_feat

        if self._normalize:
            # Weight-normalized projection vector: normed_v = g * v / ||v||.
            normed_v = self.g * self.v * tf.rsqrt(
                tf.reduce_sum(tf.square(self.v)))
            score = tf.reduce_sum(
                normed_v * tf.tanh(energy_input + self.b), [2])
        else:
            score = tf.reduce_sum(self.v * tf.tanh(energy_input), [2])

        alignments = self._probability_fn(score, previous_alignments)
    return alignments
评论列表
文章目录