attention.py source code

python

Project: document-qa    Author: allenai
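
The apply method below implements self-attention over a sequence x: it scores every pair of positions with self.attention.get_scores, masks out each position's attention to itself and to padded positions by adding VERY_NEGATIVE_NUMBER, normalizes the scores (optionally with a learned "no-alignment" bias so a position can attend to nothing), and returns the attended representation, optionally merged back with x through self.merge.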
    def apply(self, is_train, x, x_mask=None):
        x_word_dim = tf.shape(x)[1]

        # (batch, x_word, key_word)
        dist_matrix = self.attention.get_scores(x, x)
        dist_matrix += tf.expand_dims(tf.eye(x_word_dim) * VERY_NEGATIVE_NUMBER, 0)  # Mask out self

        joint_mask = compute_attention_mask(x_mask, x_mask, x_word_dim, x_word_dim)
        if joint_mask is not None:
            dist_matrix += VERY_NEGATIVE_NUMBER * (1 - tf.cast(joint_mask, dist_matrix.dtype))

        if not self.alignment_bias:
            select_probs = tf.nn.softmax(dist_matrix)
        else:
            # Allow zero-attention by adding a learned bias to the normalizer
            bias = tf.exp(tf.get_variable("no-alignment-bias", initializer=tf.constant(-1.0, dtype=tf.float32)))
            dist_matrix = tf.exp(dist_matrix)
            select_probs = dist_matrix / (tf.reduce_sum(dist_matrix, axis=2, keep_dims=True) + bias)

        response = tf.matmul(select_probs, x)  # (batch, x_words, q_dim)

        if self.merge is not None:
            with tf.variable_scope("merge"):
                response = self.merge.apply(is_train, response, x)
        return response
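
The "no-alignment bias" branch is the distinctive part: instead of a plain softmax, it adds exp of a learned scalar to the normalizer, so each position may place less than its full probability mass on the real keys. Below is a minimal NumPy sketch of that normalization; the function name and test values are illustrative, not from the project, and the default of -1.0 only mirrors the bias initializer above.

import numpy as np

def softmax_with_no_alignment_bias(scores, bias=-1.0):
    # scores: (batch, x_word, key_word) unnormalized attention logits.
    # exp(bias) acts like an extra "null" key in the normalizer, so the
    # resulting weights over the real keys sum to less than 1.
    exp_scores = np.exp(scores)
    denom = exp_scores.sum(axis=-1, keepdims=True) + np.exp(bias)
    return exp_scores / denom

scores = np.array([[[2.0, 0.5, -1.0],
                    [0.1, 0.2, 0.3]]])
probs = softmax_with_no_alignment_bias(scores)
print(probs.sum(axis=-1))  # each row sums to slightly less than 1.0

When every score is strongly negative (for example, when all keys are masked out with VERY_NEGATIVE_NUMBER), the resulting weights go to roughly zero rather than to the uniform distribution a standard softmax would produce, which is exactly the "attend to nothing" behavior the learned bias enables.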