gru.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目：theano-recurrence　作者：uyaseen　项目源码　文件源码
def generative_sampling(self, seed, emb_data, sample_length):
        """Generate a token sequence by greedily decoding from the GRU.

        Starting from the hidden state ``self.h0`` and the token index
        ``seed``, each step looks up the previous token in ``emb_data``,
        advances the GRU by one step, and picks the next token as the
        argmax of the output softmax. Decoding is greedy: no random
        sampling is performed, so the result is deterministic for fixed
        parameters and seed.

        :param seed: index of the first input token, used as the initial
            value of the token state. Presumably an integer scalar
            (TODO confirm against callers).
        :param emb_data: embedding table indexed by token id
            (``emb_data[y]``). Assumed to be a Theano matrix of shape
            (vocab_size, emb_dim); TODO confirm.
        :param sample_length: number of generation steps (scan ``n_steps``).
        :returns: the generated token indices (one per step) as produced by
            the compiled Theano function. The seed itself is not included.
        """
        fruit = theano.shared(value=seed)

        def step(h_tm, y_tm):
            # One GRU step driven by the embedding of the previous token.
            emb = emb_data[y_tm]
            x_z = T.dot(emb, self.W_z) + self.b_z
            x_r = T.dot(emb, self.W_r) + self.b_r
            x_h = T.dot(emb, self.W) + self.b_h

            z_t = self.inner_activation(x_z + T.dot(h_tm, self.U_z))  # update gate
            r_t = self.inner_activation(x_r + T.dot(h_tm, self.U_r))  # reset gate
            hh_t = self.activation(x_h + T.dot(r_t * h_tm, self.U))  # candidate state
            h_t = (T.ones_like(z_t) - z_t) * hh_t + z_t * h_tm

            # T.nnet.softmax yields a matrix (presumably a single row here,
            # assuming h0 is a vector), hence argmax over axis=1 and y[0].
            y_t = T.nnet.softmax(T.dot(h_t, self.V) + self.b_y)
            y = T.argmax(y_t, axis=1)

            # argmax always produces int64, but `fruit` takes the dtype of
            # `seed` (which may be int32). scan requires each step output to
            # have the same type as its outputs_info entry, so cast
            # explicitly; this is a no-op when the dtypes already agree.
            return h_t, T.cast(y[0], fruit.dtype)

        [_, samples], updates = theano.scan(fn=step,
                                            outputs_info=[self.h0, fruit],
                                            n_steps=sample_length)

        # Forward scan's updates (empty unless the step uses shared random
        # streams), as the scan API expects them to be passed to the
        # compiled function.
        get_samples = theano.function(inputs=[],
                                      outputs=samples,
                                      updates=updates)

        return get_samples()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号