modules.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch_workplace 作者: DingKe 项目源码 文件源码
def forward(self, input, hx):
        h, c = hx

        pre = F.linear(input, self.weight_ih, self.bias) \
                    + F.linear(h, self.weight_hh)

        pre = sparsify_grad(pre, self.k, self.simplified)

        if self.grad_clip:
            pre = clip_grad(pre, -self.grad_clip, self.grad_clip)

        i = F.sigmoid(pre[:, :self.hidden_size])
        f = F.sigmoid(pre[:, self.hidden_size: self.hidden_size * 2])
        g = F.tanh(pre[:, self.hidden_size * 2: self.hidden_size * 3])
        o = F.sigmoid(pre[:, self.hidden_size * 3:])

        c = f * c + i * g
        h = o * F.tanh(c)
        return h, c
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号