import tensorflow as tf


def _setup(self, seq, vec, activation=tf.nn.tanh):
"""Setup a soft attention mechanism for the given context sequence and state.
The result is an attention context for the state.
:param seq: The sequence tensor.
Its shape is defined as (seq_length, batch_size, seq_elem_size).
:param vec: The vector tensor.
Its shape is defined as (batch_size, vec_size).
    :param activation: The activation function applied to the summed
        projections; pass None to apply no activation.
        Default is tf.nn.tanh.
:return: An attention context with shape (batch_size, seq_elem_size).
"""
#
# (seq_length, batch_size, seq_elem_size) @ (seq_elem_size, common_size)
# -> (seq_length, batch_size, common_size)
a = tf.tensordot(seq, self._w, ((2,), (0,)))
#
# (batch_size, vec_size) @ (vec_size, common_size)
# -> (batch_size, common_size)
# -> (1, batch_size, common_size)
b = tf.matmul(vec, self._u)
b = tf.reshape(b, (1, -1, self._common_size))
    #
    # (seq_length, batch_size, common_size) + (1, batch_size, common_size)
    # -> (seq_length, batch_size, common_size)   (b broadcasts over seq_length)
    # (seq_length, batch_size, common_size) @ (common_size, 1)
    # -> (seq_length, batch_size, 1)
    a = activation(a + b) if activation is not None else a + b
    a = tf.tensordot(a, self._omega, ((2,), (0,)))
    # Normalize the scores along the sequence dimension so the weights
    # for each batch element sum to 1.
    a = tf.nn.softmax(a, axis=0)
#
# (seq_length, batch_size, 1) * (seq_length, batch_size, seq_elem_size)
# -> (seq_length, batch_size, seq_elem_size)
# -> (batch_size, seq_elem_size)
    att_context = tf.reduce_sum(a * seq, axis=0)
return att_context
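
For completeness, here is a minimal sketch of the object this method expects to live on. The class name `SoftAttention`, the shape constants, and the truncated-normal initialization are assumptions for illustration; all `_setup` actually requires is that `self._w`, `self._u`, `self._omega`, and `self._common_size` exist with the shapes noted in the comments above.

# Hypothetical sizes, chosen only for the demo.
seq_length, batch_size = 20, 8
seq_elem_size, vec_size, common_size = 64, 32, 16

class SoftAttention:
    def __init__(self):
        self._common_size = common_size
        # (seq_elem_size, common_size): projects each sequence element.
        self._w = tf.Variable(
            tf.random.truncated_normal((seq_elem_size, common_size), stddev=0.1))
        # (vec_size, common_size): projects the state vector.
        self._u = tf.Variable(
            tf.random.truncated_normal((vec_size, common_size), stddev=0.1))
        # (common_size, 1): maps each projected element to a scalar score.
        self._omega = tf.Variable(
            tf.random.truncated_normal((common_size, 1), stddev=0.1))

SoftAttention._setup = _setup  # attach the method defined above

attn = SoftAttention()
seq = tf.random.normal((seq_length, batch_size, seq_elem_size))
vec = tf.random.normal((batch_size, vec_size))
context = attn._setup(seq, vec)
print(context.shape)  # -> (8, 64), i.e. (batch_size, seq_elem_size)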