tree_gru.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目: tree_rnn 作者: ofirnachum 项目源码 文件源码
def create_recursive_unit(self):
        """Create the Tree-GRU parameters and return the per-node update function.

        Six shared parameters are allocated, in this order, and appended to
        ``self.params`` in the same order:

          * W_z, W_r, W_h -- shape [hidden_dim, emb_dim], applied to the
            node's own input vector.
          * U_z, U_r      -- shape [degree, hidden_dim, hidden_dim], one
            recurrent matrix per child slot.
          * U_h           -- shape [hidden_dim, hidden_dim], shared by all
            children.

        Returns:
            A function ``unit(parent_x, child_h, child_exists)`` that builds
            the symbolic hidden state of one node from its input and its
            children's hidden states.
        """
        hid, emb, deg = self.hidden_dim, self.emb_dim, self.degree

        def new_param(shape):
            # Allocation order matters if init_matrix draws from a shared RNG,
            # so the calls below keep the original W_z, U_z, W_r, ... sequence.
            return theano.shared(self.init_matrix(shape))

        self.W_z = new_param([hid, emb])
        self.U_z = new_param([deg, hid, hid])
        self.W_r = new_param([hid, emb])
        self.U_r = new_param([deg, hid, hid])
        self.W_h = new_param([hid, emb])
        self.U_h = new_param([hid, hid])
        self.params.extend([self.W_z, self.U_z,
                            self.W_r, self.U_r,
                            self.W_h, self.U_h])

        def child_preactivations(u_z_slot, u_r_slot, h_child):
            # Project one child's hidden state through its slot's update-gate
            # and reset-gate recurrent matrices.
            return T.dot(u_z_slot, h_child), T.dot(u_r_slot, h_child)

        def unit(parent_x, child_h, child_exists):
            # theano.map walks U_z, U_r and child_h together along their first
            # axis, so row k of pre_z / pre_r belongs to child slot k.
            # Presumably child_h is (degree, hidden_dim) -- TODO confirm at caller.
            (pre_z, pre_r), _updates = theano.map(
                fn=child_preactivations,
                sequences=[self.U_z, self.U_r, child_h])

            # The parent's input term is a vector; dimshuffle('x', 0) turns it
            # into a row that broadcasts across every child row.
            # NOTE(review): assumes _softmax (defined elsewhere) normalizes over
            # child rows and masks absent children via child_exists. With
            # add_one=True it appears to leave some weight unassigned, which
            # becomes the weight on h_hat below -- confirm against _softmax.
            z_logits = T.dot(self.W_z, parent_x).dimshuffle('x', 0) + pre_z
            z = _softmax(z_logits, child_exists, add_one=True)
            r_logits = T.dot(self.W_r, parent_x).dimshuffle('x', 0) + pre_r
            r = _softmax(r_logits, child_exists, add_one=False)

            # Candidate state: own input plus the reset-weighted sum of children.
            reset_children = T.sum(r * child_h, axis=0)
            h_hat = T.tanh(T.dot(self.W_h, parent_x) +
                           T.dot(self.U_h, reset_children))

            # Interpolate: whatever weight z does not give to the children goes
            # to the candidate state.
            keep_new = 1 - T.sum(z, axis=0)
            carried = T.sum(z * child_h, axis=0)
            return keep_new * h_hat + carried

        return unit
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号