attention.py source code

python

Project: document-qa  Author: allenai
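def apply(self, is_train, x, mask=None):
    # Optionally project the inputs into "keys" that are used only for scoring.
    if self.key_mapper is not None:
        with tf.variable_scope("map_keys"):
            keys = self.key_mapper.apply(is_train, x, mask)
    else:
        keys = x

    # One learned weight vector, with the same size as the last dimension of the keys.
    weights = tf.get_variable("weights", keys.shape.as_list()[-1], dtype=tf.float32,
                              initializer=get_keras_initialization(self.init))
    # Score each word by the dot product of its key with the weight vector.
    dist = tf.tensordot(keys, weights, axes=[[2], [0]])  # (batch, x_words)
    dist = exp_mask(dist, mask)  # suppress masked positions before the softmax
    dist = tf.nn.softmax(dist)   # attention distribution over x_words

    # Weighted sum of the original inputs (not the keys) under the attention distribution.
    out = tf.einsum("ajk,aj->ak", x, dist)  # (batch, x_dim)

    if self.post_process is not None:
        with tf.variable_scope("post_process"):
            out = self.post_process.apply(is_train, out)
    return out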
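
The core of the method above (score, mask, softmax, weighted sum) can be sketched in NumPy to make the shapes easier to follow. This is a minimal illustration, not the project's code: `attention_pool`, the local `softmax`, and the stand-in `exp_mask` are hypothetical helpers, and the sketch assumes that `mask` holds per-example sequence lengths and that masking works by pushing padded scores to a very negative value. It also skips the optional `key_mapper` and `post_process` steps.

```python
import numpy as np


def exp_mask(logits, lengths, very_negative=-1e30):
    # Hypothetical stand-in for the project's exp_mask: assumes `lengths`
    # gives the number of valid words in each example.
    if lengths is None:
        return logits
    positions = np.arange(logits.shape[1])[None, :]   # (1, x_words)
    valid = positions < np.asarray(lengths)[:, None]  # (batch, x_words)
    return np.where(valid, logits, very_negative)


def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attention_pool(x, weights, lengths=None):
    # x: (batch, x_words, x_dim), weights: (x_dim,)
    scores = np.tensordot(x, weights, axes=[[2], [0]])  # (batch, x_words)
    dist = softmax(exp_mask(scores, lengths))           # (batch, x_words)
    out = np.einsum("ajk,aj->ak", x, dist)              # (batch, x_dim)
    return out, dist


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 3))  # 2 examples, 4 words, 3 features
w = rng.normal(size=3)
out, dist = attention_pool(x, w, lengths=[4, 2])
print(out.shape, dist.shape)
```

In this sketch the second example has only two valid words, so its last two attention weights come out as zero and its output is a weighted average of the first two word vectors only.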