recurrent.py source code

python

Project: NumpyDL  Author: oujago
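def forward(self, input, *args, **kwargs):
    # GRU forward pass. Relies on module-level `np` (NumPy) and NumpyDL's
    # `_zero` helper, which allocates a zero array of the given shape.
    assert np.ndim(input) == 3, 'Only batch training is supported.'

    # record the input for later use (e.g. backpropagation)
    self.last_input = input

    # input shape: (batch size, timesteps, input features)
    nb_batch, nb_timesteps, nb_in = input.shape

    # allocate the output: one hidden state per sample and timestep
    output = _zero((nb_batch, nb_timesteps, self.n_out))

    # step through the sequence one timestep at a time
    for i in range(nb_timesteps):
        # previous hidden state (zeros at the first step) and current input
        s_pre = _zero((nb_batch, self.n_out)) if i == 0 else output[:, i - 1, :]
        x_now = input[:, i, :]

        # update gate z, reset gate r, and candidate state h
        z_now = self.gate_activation.forward(np.dot(x_now, self.U_z) +
                                             np.dot(s_pre, self.W_z) +
                                             self.b_z)
        r_now = self.gate_activation.forward(np.dot(x_now, self.U_r) +
                                             np.dot(s_pre, self.W_r) +
                                             self.b_r)
        h_now = self.activation.forward(np.dot(x_now, self.U_h) +
                                        np.dot(s_pre * r_now, self.W_h) +
                                        self.b_h)
        # new state: z interpolates between the candidate and the previous state
        output[:, i, :] = (1 - z_now) * h_now + z_now * s_pre

    # record the full output sequence
    self.last_output = output

    # return every timestep's state, or only the final one
    if self.return_sequence:
        return self.last_output
    else:
        return self.last_output[:, -1, :]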
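The method above depends on the layer object (weights, activations, `_zero`). The following is a minimal standalone sketch of the same recurrence, useful for checking shapes outside the class. It assumes the gate activation is a sigmoid and the candidate activation is tanh (the original takes both from `self.gate_activation` and `self.activation`), and it infers weight shapes from the dot products: `U_*` is `(n_in, n_out)`, `W_*` is `(n_out, n_out)`, `b_*` is `(n_out,)`. The names `gru_forward`, `sigmoid`, and the `params` dict are hypothetical, not part of NumpyDL.

```python
import numpy as np


def sigmoid(x):
    # assumed gate activation (the original uses self.gate_activation)
    return 1.0 / (1.0 + np.exp(-x))


def gru_forward(inputs, params, return_sequence=True):
    """Hypothetical standalone GRU forward pass over (batch, timesteps, n_in).

    params holds U_z, U_r, U_h (n_in, n_out), W_z, W_r, W_h (n_out, n_out)
    and b_z, b_r, b_h (n_out,), mirroring the attribute names in forward().
    """
    assert inputs.ndim == 3, "Only batch input (batch, timesteps, features) is supported."
    nb_batch, nb_timesteps, _ = inputs.shape
    n_out = params["W_z"].shape[0]
    output = np.zeros((nb_batch, nb_timesteps, n_out))
    s_pre = np.zeros((nb_batch, n_out))  # initial hidden state is all zeros
    for t in range(nb_timesteps):
        x_now = inputs[:, t, :]
        z = sigmoid(x_now @ params["U_z"] + s_pre @ params["W_z"] + params["b_z"])
        r = sigmoid(x_now @ params["U_r"] + s_pre @ params["W_r"] + params["b_r"])
        # assumed candidate activation: tanh (the original uses self.activation)
        h = np.tanh(x_now @ params["U_h"] + (s_pre * r) @ params["W_h"] + params["b_h"])
        s_pre = (1 - z) * h + z * s_pre  # same interpolation as the original
        output[:, t, :] = s_pre
    return output if return_sequence else output[:, -1, :]


# Usage: small random weights, 2 sequences of 5 steps with 4 features each.
rng = np.random.default_rng(0)
n_in, n_out = 4, 3
params = {}
for g in "zrh":
    params[f"U_{g}"] = rng.standard_normal((n_in, n_out)) * 0.1
    params[f"W_{g}"] = rng.standard_normal((n_out, n_out)) * 0.1
    params[f"b_{g}"] = np.zeros(n_out)

x = rng.standard_normal((2, 5, n_in))
seq = gru_forward(x, params)
last = gru_forward(x, params, return_sequence=False)
print(seq.shape, last.shape)  # (2, 5, 3) (2, 3)
```

Carrying the previous state in `s_pre` is equivalent to the original's reading of `output[:, i - 1, :]`. Because each new state is a convex combination of a tanh output and a previous state that starts at zero, every hidden value stays strictly inside (-1, 1) under these assumed activations.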