rnn.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:aRNNie 作者: MojoJolo 项目源码 文件源码
def forward(self, inputs, targets, hidden_prev):
        """Run the network forward over a sequence and accumulate the loss.

        For every position ``i`` of ``inputs`` this builds a one-hot column
        vector, computes the next hidden state from it and the previous hidden
        state, projects the hidden state to output scores, converts the scores
        to probabilities, and adds the negative log of the probability given
        to ``targets[i]`` to the running loss.

        Args:
            inputs: Sequence of integer indices; each selects the row set to 1
                in a ``(self.vocab_size, 1)`` one-hot vector.
            targets: Indexed in step with ``inputs``; ``targets[i]`` selects
                the row of ``probs[i]`` that enters the loss, so it must
                provide an entry for every position of ``inputs``.
            hidden_prev: Hidden state preceding the first step. It is copied
                before use, so the caller's array is not written to here.

        Returns:
            A tuple ``(input_xs, output_ys, hidden_s, probs, loss)``. The
            first four are dicts keyed by step index; ``hidden_s`` also holds
            the copy of ``hidden_prev`` under key ``-1``. ``loss`` is the sum
            over all steps of the negative log-probability of the target
            (``0`` when ``inputs`` is empty).
        """
        input_xs = {}   # step -> one-hot input column vector
        hidden_s = {}   # step -> hidden state (key -1 is the initial state)
        output_ys = {}  # step -> output scores before softmax
        probs = {}      # step -> probabilities produced by self.softmax

        hidden_s[-1] = np.copy(hidden_prev)
        loss = 0

        # enumerate() replaces the Python-2-only xrange(len(inputs)) so the
        # method runs unchanged on both Python 2 and Python 3.
        for i, input_index in enumerate(inputs):
            # Create the equivalent one-hot vector for this input
            one_hot = np.zeros((self.vocab_size, 1))
            one_hot[input_index] = 1
            input_xs[i] = one_hot

            # Compute the current hidden state from the current input and the
            # previous hidden state.
            # NOTE(review): self.tanh is defined outside this block; presumably
            # it applies tanh to W_xh.x + W_hh.h_prev + bias -- confirm.
            hidden_s[i] = self.tanh(
                self.param_w_xh,
                one_hot,
                self.param_w_hh,
                hidden_s[i - 1],
                self.bias_hidden,
            )
            output_ys[i] = np.dot(self.param_w_hy, hidden_s[i]) + self.bias_output_y

            probs[i] = self.softmax(output_ys[i])
            # Penalise a low probability assigned to the expected target.
            loss += -np.log(probs[i][targets[i], 0])

        return input_xs, output_ys, hidden_s, probs, loss

    # backprop
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号