recurrent.py source code

python

Project: NumpyDL  Author: oujago
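# Excerpt: a method of a recurrent layer class. `np` is numpy
# (import numpy as np) and `_zero` is a helper defined elsewhere in the
# project; it appears to allocate a zero-filled array of the given shape.
def forward(self, input, *args, **kwargs):
    # Expect input of shape (nb_batch, nb_timestep, nb_in).
    assert np.ndim(input) == 3, 'Only support batch training.'

    self.last_input = input
    nb_batch, nb_timestep, nb_in = input.shape
    output = _zero((nb_batch, nb_timestep, self.n_out))

    # Lazily create one activation object per time step on the first call.
    # The list is not rebuilt afterwards, so a later call with a longer
    # sequence would index past its end.
    if len(self.activations) == 0:
        self.activations = [self.activation_cls() for _ in range(nb_timestep)]

    # First step: no previous output exists, so the recurrent term is omitted.
    output[:, 0, :] = self.activations[0].forward(np.dot(input[:, 0, :], self.W) + self.b)

    # Remaining steps: combine the current input with the previous output.
    for i in range(1, nb_timestep):
        output[:, i, :] = self.activations[i].forward(
            np.dot(input[:, i, :], self.W) +
            np.dot(output[:, i - 1, :], self.U) + self.b)

    self.last_output = output
    # Return the whole sequence, or only the output of the last time step.
    if self.return_sequence:
        return self.last_output
    else:
        return self.last_output[:, -1, :]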
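
The method above depends on attributes set elsewhere in its class, so it cannot be run on its own. The following is a minimal, self-contained sketch of the same recurrence, h_t = f(x_t W + h_{t-1} U + b), for experimenting with shapes. `SimpleRNNSketch`, `Tanh`, the weight initialisation, and the choice of tanh are all assumptions made for illustration; they are not taken from NumpyDL. The sketch starts from a zero initial state, which gives the same first step as the original because the recurrent term is then zero.

```python
import numpy as np


class Tanh:
    """Hypothetical stand-in for the layer's activation class."""

    def forward(self, x):
        # Cache the result, as a backward pass would typically need it.
        self.last_forward = np.tanh(x)
        return self.last_forward


class SimpleRNNSketch:
    """Illustrative layer; names and initialisation are assumptions."""

    def __init__(self, n_in, n_out, return_sequence=False, seed=0):
        rng = np.random.default_rng(seed)
        self.n_out = n_out
        self.return_sequence = return_sequence
        self.W = rng.normal(scale=0.1, size=(n_in, n_out))   # input weights
        self.U = rng.normal(scale=0.1, size=(n_out, n_out))  # recurrent weights
        self.b = np.zeros(n_out)

    def forward(self, x):
        assert np.ndim(x) == 3, 'expected (batch, timestep, features)'
        nb_batch, nb_timestep, _ = x.shape
        output = np.zeros((nb_batch, nb_timestep, self.n_out))
        # One activation object per time step, rebuilt on every call.
        activations = [Tanh() for _ in range(nb_timestep)]
        h_prev = np.zeros((nb_batch, self.n_out))  # zero initial state
        for t in range(nb_timestep):
            pre = x[:, t, :] @ self.W + h_prev @ self.U + self.b
            output[:, t, :] = activations[t].forward(pre)
            h_prev = output[:, t, :]
        return output if self.return_sequence else output[:, -1, :]


x = np.random.default_rng(1).normal(size=(4, 5, 3))
seq = SimpleRNNSketch(3, 2, return_sequence=True).forward(x)
last = SimpleRNNSketch(3, 2, return_sequence=False).forward(x)
print(seq.shape, last.shape)
```

With `return_sequence=True` the result keeps the time axis; otherwise only the final time step is returned, matching the two branches at the end of the original method.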