attention.py source code

Language: Python

Project: deepsphinx, Author: vagrawal
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.rsqrt)

# FLAGS is deepsphinx's command-line flags object, defined elsewhere in the project.
# This is a method of the project's attention class; the class body is not shown.


def __call__(self, query, previous_alignments):
    '''Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      previous_alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
    '''
    with tf.variable_scope(None, 'bahdanau_attention', [query]):
        processed_query = self.query_layer(
            query) if self.query_layer else query
        # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
        processed_query = tf.expand_dims(processed_query, 1)
        keys = self._keys
        energy_input = keys + processed_query
        if FLAGS.use_conv_feat_att:
            # Location features: convolve the previous alignments over time.
            # [batch_size, max_time] -> [batch_size, max_time, 1] -> conv1d
            # -> [batch_size, max_time, out_channels]; out_channels must equal
            # the key depth so the features can be added to the keys.
            conv_feat = tf.nn.conv1d(
                tf.expand_dims(previous_alignments, 2),
                self.conv_filt, 1, 'SAME')
            # Added to the energy in both branches below, so the flag also
            # takes effect when weight normalization is enabled.
            energy_input = energy_input + conv_feat
        if self._normalize:
            # normed_v = g * v / ||v||
            normed_v = self.g * self.v * tf.rsqrt(
                tf.reduce_sum(tf.square(self.v)))
            score = tf.reduce_sum(
                normed_v * tf.tanh(energy_input + self.b), [2])
        else:
            score = tf.reduce_sum(self.v * tf.tanh(energy_input), [2])

    alignments = self._probability_fn(score, previous_alignments)
    return alignments
```
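To make the tensor shapes in the `__call__` method above easier to follow, here is a minimal NumPy sketch of the same additive (Bahdanau) score, e_j = v · tanh(k_j + q + f_j), where f_j are the optional location features from convolving the previous alignments. It is an illustration under stated assumptions, not the project's implementation: the helper names `conv1d_same`, `additive_score` and `softmax` are hypothetical; the filter is assumed to have a single input channel (shape `[width, units]` instead of TensorFlow's `[width, 1, units]`); and `self._probability_fn` is assumed to be a softmax over time, which is TensorFlow's default for Bahdanau attention but is not confirmed by this file.

```python
import numpy as np


def conv1d_same(x, filt):
    """Hypothetical helper: cross-correlate x [batch, time] with filt [width, units],
    zero-padding like TensorFlow's 'SAME' (stride 1). Returns [batch, time, units]."""
    width = filt.shape[0]
    pad_left = (width - 1) // 2          # TF puts any extra padding on the right
    pad_right = width - 1 - pad_left
    padded = np.pad(x, ((0, 0), (pad_left, pad_right)))
    batch, time = x.shape
    out = np.zeros((batch, time, filt.shape[1]))
    for t in range(time):
        window = padded[:, t:t + width]  # [batch, width]
        out[:, t, :] = window @ filt     # [batch, units]
    return out


def additive_score(keys, query, v, prev_alignments=None, conv_filt=None,
                   g=None, b=None):
    """Hypothetical helper. keys: [batch, time, units] (projected memory),
    query: [batch, units] (projected query), v: [units]. Returns [batch, time]."""
    energy = keys + query[:, None, :]    # broadcast the query over time
    if prev_alignments is not None:
        # Assumed location features: conv over the previous alignments.
        energy = energy + conv1d_same(prev_alignments, conv_filt)
    if g is not None:
        v = g * v / np.sqrt(np.sum(v ** 2))  # normed_v = g * v / ||v||
        energy = energy + b
    return np.sum(v * np.tanh(energy), axis=2)


def softmax(score):
    """Numerically stable softmax over the time axis."""
    e = np.exp(score - score.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


# Usage with random (made-up) inputs.
rng = np.random.default_rng(0)
batch, time, units, width = 2, 5, 4, 3
keys = rng.normal(size=(batch, time, units))
query = rng.normal(size=(batch, units))
v = rng.normal(size=units)
prev = softmax(rng.normal(size=(batch, time)))
conv_filt = rng.normal(size=(width, units))
alignments = softmax(additive_score(keys, query, v, prev, conv_filt))
print(alignments.shape)  # (2, 5)
```

Passing `g` and `b` switches to the weight-normalized variant; with `g = ||v||` and `b = 0` it reduces to the plain score, which is a quick way to check that the normalization only rescales `v`.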