widgets.py source code

python

Project: photinia  Author: XoriieInpottn
import tensorflow as tf


# This is a method of a class in photinia's widgets.py (the class itself is not shown);
# self._w, self._u, self._omega and self._common_size are created elsewhere in that class.
def _setup(self, seq, vec, activation=tf.nn.tanh):
    """Set up a soft attention mechanism for the given context sequence and state.
    The result is an attention context for the state.

    :param seq: The sequence tensor.
        Its shape is (seq_length, batch_size, seq_elem_size).
    :param vec: The vector (state) tensor.
        Its shape is (batch_size, vec_size).
    :param activation: The activation function applied to the summed projections.
        Default is tf.nn.tanh; pass None to skip it.
    :return: An attention context with shape (batch_size, seq_elem_size).
    """
    #
    # (seq_length, batch_size, seq_elem_size) @ (seq_elem_size, common_size)
    # -> (seq_length, batch_size, common_size)
    a = tf.tensordot(seq, self._w, ((2,), (0,)))
    #
    # (batch_size, vec_size) @ (vec_size, common_size)
    # -> (batch_size, common_size)
    # -> (1, batch_size, common_size)
    b = tf.matmul(vec, self._u)
    b = tf.reshape(b, (1, -1, self._common_size))
    #
    # a + b broadcasts to (seq_length, batch_size, common_size)
    # (seq_length, batch_size, common_size) @ (common_size, 1)
    # -> (seq_length, batch_size, 1)
    a = activation(a + b) if activation is not None else a + b
    a = tf.tensordot(a, self._omega, ((2,), (0,)))
    # Normalize over the time axis (axis 0). `axis` replaces the deprecated
    # `dim` argument, which TensorFlow 2 no longer accepts.
    a = tf.nn.softmax(a, axis=0)
    #
    # (seq_length, batch_size, 1) * (seq_length, batch_size, seq_elem_size)
    # -> (seq_length, batch_size, seq_elem_size)
    # -> (batch_size, seq_elem_size)
    att_context = tf.reduce_sum(a * seq, 0)
    return att_context
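Written out per batch element, the method computes additive (Bahdanau-style) attention. The symbols here are our own notation, not the author's: $s_t$ is the $t$-th time step of `seq`, $v$ is `vec`, $W$, $U$ and $\omega$ stand for `self._w`, `self._u` and `self._omega`, $f$ is `activation` (the identity when it is `None`), and $T$ is `seq_length`.

$$
\begin{aligned}
e_t &= \omega^\top f\!\left(W^\top s_t + U^\top v\right), \qquad t = 1, \dots, T,\\
\alpha_t &= \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)},\\
c &= \sum_{t=1}^{T} \alpha_t\, s_t .
\end{aligned}
$$

Because the weights $\alpha_t$ are non-negative and sum to 1, the returned context $c$ is a convex combination of the time steps of `seq`.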
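To trace the shapes without a TensorFlow graph, the same computation can be sketched in NumPy. This is a minimal, hypothetical standalone version, not photinia's code: the function name `soft_attention` and the explicit parameters `w`, `u` and `omega` (standing in for `self._w`, `self._u` and `self._omega`) are assumptions for illustration, and the max-shift inside the softmax is a common numerical-stability choice rather than something taken from the original.

```python
import numpy as np


def soft_attention(seq, vec, w, u, omega, activation=np.tanh):
    """Hypothetical NumPy sketch of the additive attention in `_setup`.

    seq:   (seq_length, batch_size, seq_elem_size)
    vec:   (batch_size, vec_size)
    w:     (seq_elem_size, common_size)  # stands in for self._w
    u:     (vec_size, common_size)       # stands in for self._u
    omega: (common_size, 1)              # stands in for self._omega
    Returns (context, weights) with shapes
    (batch_size, seq_elem_size) and (seq_length, batch_size, 1).
    """
    # Project every time step: -> (seq_length, batch_size, common_size)
    a = np.tensordot(seq, w, axes=([2], [0]))
    # Project the state and add a leading axis so it broadcasts over time:
    # -> (1, batch_size, common_size)
    b = (vec @ u)[np.newaxis, :, :]
    h = activation(a + b) if activation is not None else a + b
    # Score each time step: -> (seq_length, batch_size, 1)
    scores = np.tensordot(h, omega, axes=([2], [0]))
    # Softmax over axis 0 (time), shifted by the max for numerical stability.
    scores = scores - scores.max(axis=0, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=0, keepdims=True)
    # Weighted sum over time: -> (batch_size, seq_elem_size)
    context = (weights * seq).sum(axis=0)
    return context, weights


# Usage with random inputs (sizes chosen arbitrarily for illustration).
rng = np.random.default_rng(0)
T, B, D_s, D_v, D_c = 5, 3, 4, 6, 8
seq = rng.standard_normal((T, B, D_s))
vec = rng.standard_normal((B, D_v))
w = rng.standard_normal((D_s, D_c))
u = rng.standard_normal((D_v, D_c))
omega = rng.standard_normal((D_c, 1))

context, weights = soft_attention(seq, vec, w, u, omega)
print(context.shape)  # (3, 4)
```

A quick sanity check on the design: with `omega` set to zeros every score is equal, so each weight becomes `1 / T` and the context reduces to `seq.mean(axis=0)`, the plain average over time steps.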