rnn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:theano-recurrence 作者: uyaseen 项目源码 文件源码
def generative_sampling(self, seed, emb_data, sample_length):
        fruit = theano.shared(value=seed)

        def step(h_tm, y_tm):
            h_t = self.activation(T.dot(emb_data[y_tm], self.W) +
                                  T.dot(h_tm, self.U) + self.bh)
            y_t = T.nnet.softmax(T.dot(h_t, self.V) + self.by)
            y = T.argmax(y_t, axis=1)

            return h_t, y[0]

        [_, samples], _ = theano.scan(fn=step,
                                      outputs_info=[self.h0, fruit],
                                      n_steps=sample_length)

        get_samples = theano.function(inputs=[],
                                      outputs=samples)

        return get_samples()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号