rnn.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:aRNNie 作者: MojoJolo 项目源码 文件源码
def generate(self, hidden, seed_ix, chars_counter):
        input_x = np.zeros((self.vocab_size, 1))
        input_x[seed_ix] = 1
        ixes = []

        for i in xrange(chars_counter):
            hidden = np.tanh(np.dot(self.param_w_xh, input_x) + np.dot(self.param_w_hh, hidden) + self.bias_hidden) # tanh
            output_y = np.dot(self.param_w_hy, hidden) + self.bias_output_y
            prob = self.softmax(output_y)
            ix = np.random.choice(range(self.vocab_size), p=prob.ravel())

            input_x = np.zeros((self.vocab_size, 1))
            input_x[ix] = 1

            ixes.append(ix)

        return [self.ix_to_char[ix] for ix in ixes]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号