attention.py source code

python

Project: document-qa  Author: allenai
def apply(self, is_train, x, mask=None):
        if self.key_mapper is not None:
            with tf.variable_scope("map_keys"):
                keys = self.key_mapper.apply(is_train, x, mask)
        else:
            keys = x

        weights = tf.get_variable("weights", (keys.shape.as_list()[-1], self.n_encodings), dtype=tf.float32,
                                  initializer=get_keras_initialization(self.init))
        dist = tf.tensordot(keys, weights, axes=[[2], [0]])  # (batch, x_words, n_encoding)
        if self.bias:
            dist += tf.get_variable("bias", (1, 1, self.n_encodings),
                                    dtype=tf.float32, initializer=tf.zeros_initializer())
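        if mask is not None:
            # `mask` holds per-example sequence lengths; build a (batch, x_words, 1) float mask
            bool_mask = tf.expand_dims(tf.cast(tf.sequence_mask(mask, tf.shape(x)[1]), tf.float32), 2)
            # Keep the scores of real words and push padded positions to a very negative value
            dist = bool_mask * dist + (1 - bool_mask) * VERY_NEGATIVE_NUMBER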

        dist = tf.nn.softmax(dist, dim=1)

        out = tf.einsum("ajk,ajn->ank", x, dist)  # (batch, n_encoding, feature)

        if self.post_process is not None:
            with tf.variable_scope("post_process"):
                out = self.post_process.apply(is_train, out)
        return out
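
The pooling step above can be checked outside TensorFlow. The following is a minimal NumPy sketch of the same computation, not the project's implementation: the function name `multi_encoding_pool`, its argument names, and the value chosen for `VERY_NEGATIVE_NUMBER` are assumptions made for illustration, and the optional `key_mapper` and `post_process` layers are left out.

```python
import numpy as np

VERY_NEGATIVE_NUMBER = -1e29  # assumed value; the snippet above only references the name


def multi_encoding_pool(x, weights, lengths=None, bias=None):
    """Pool a padded sequence into a fixed number of weighted sums.

    x:       (batch, x_words, features) float array
    weights: (features, n_encodings) float array
    lengths: optional (batch,) array with the number of real words per example
    bias:    optional array broadcastable to (1, 1, n_encodings)
    """
    # Score every word once per encoding: (batch, x_words, n_encodings)
    dist = np.tensordot(x, weights, axes=[[2], [0]])
    if bias is not None:
        dist = dist + bias
    if lengths is not None:
        # 1.0 for real words, 0.0 for padding: (batch, x_words, 1)
        positions = np.arange(x.shape[1])[None, :]
        keep = (positions < np.asarray(lengths)[:, None]).astype(float)[:, :, None]
        dist = keep * dist + (1 - keep) * VERY_NEGATIVE_NUMBER
    # Numerically stable softmax over the word axis (axis 1)
    dist = np.exp(dist - dist.max(axis=1, keepdims=True))
    dist = dist / dist.sum(axis=1, keepdims=True)
    # Weighted sum of the inputs for each encoding: (batch, n_encodings, features)
    return np.einsum("ajk,ajn->ank", x, dist)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 3))        # 2 examples, 4 word slots, 3 features
weights = rng.normal(size=(3, 2))     # 2 encodings
out = multi_encoding_pool(x, weights, lengths=[4, 2])
print(out.shape)  # (2, 2, 3)
```

Because padded positions receive a softmax weight of zero, changing the values stored in the padding slots of `x` leaves the output for that example unchanged, which is a convenient sanity check for the masking line.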