model.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:qrn 作者: uwnlp 项目源码 文件源码
def __call__(self, u_t, a, b, scope=None):
        """
        Build the ops that mix ``a * u_t`` across the memory axis.

        The mixing weights are ``L * exp(L @ (log(b) * sL))``, where ``L``
        and ``sL`` are read from the instance and the ``log(b)`` entry of
        the first memory slot is replaced by zero.

        :param u_t: [N, M, d]
        :param a: [N, M, d]
        :param b: [N, M, d]
        :param scope: optional name for the enclosing ``tf.name_scope``;
            the class name is used when it is falsy
        :return: [N, M, d]
        """
        batch, mem, hidden = self.batch_size, self.mem_size, self.hidden_size
        base_l, base_sl = self.L, self.sL
        scope_name = scope or self.__class__.__name__
        with tf.name_scope(scope_name):
            # Replicate both matrices over the batch and hidden axes.
            # NOTE(review): assumes self.L / self.sL are [M, M] -- confirm.
            tile_shape = [batch, hidden, 1, 1]
            l_tiled = tf.tile(
                tf.expand_dims(tf.expand_dims(base_l, 0), 0), tile_shape)
            sl_tiled = tf.tile(
                tf.expand_dims(tf.expand_dims(base_sl, 0), 0), tile_shape)

            # Small epsilon keeps the log finite when b is exactly zero.
            log_b = tf.log(b + 1e-9)  # [N, M, d]

            # Swap the first memory slot's log(b) for zeros, keep the rest.
            first_slot = tf.zeros([batch, 1, hidden])
            later_slots = tf.slice(log_b, [0, 1, 0], [-1, -1, -1])
            log_b = tf.concat(1, [first_slot, later_slots])  # [N, M, d]

            log_b_col = tf.transpose(log_b, [0, 2, 1])
            log_b_col = tf.expand_dims(log_b_col, -1)  # [N, d, M, 1]

            # Weights: L * exp(L @ (log(b) * sL)), computed in log space.
            scaled = log_b_col * sl_tiled
            log_weights = tf.batch_matmul(l_tiled, scaled)
            weights = l_tiled * tf.exp(log_weights)  # [N, d, M, M]

            gated = a * u_t  # [N, M, d]
            gated_col = tf.transpose(gated, [0, 2, 1])
            gated_col = tf.expand_dims(gated_col, -1)  # [N, d, M, 1]

            mixed = tf.batch_matmul(weights, gated_col)  # [N, d, M, 1]
            mixed = tf.squeeze(mixed, [3])  # [N, d, M]
            result = tf.transpose(mixed, [0, 2, 1])  # [N, M, d]
        return result
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号