modules.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pytorch_workplace 作者: DingKe 项目源码 文件源码
def forward(self, input, h):
        ih = F.linear(input, self.weight_ih, self.bias)
        hh_rz = F.linear(h, self.weight_hh_rz)

        if self.grad_clip:
            ih = clip_grad(ih, -self.grad_clip, self.grad_clip)
            hh_rz = clip_grad(hh_rz, -self.grad_clip, self.grad_clip)

        r = F.sigmoid(ih[:, :self.hidden_size] + hh_rz[:, :self.hidden_size])
        i = F.sigmoid(ih[:, self.hidden_size: self.hidden_size * 2] + hh_rz[:, self.hidden_size:])

        hhr = F.linear(h * r, self.weight_hh)
        if self.grad_clip:
            hhr = clip_grad(hhr, -self.grad_clip, self.grad_clip)

        n = F.relu(ih[:, self.hidden_size * 2:] + hhr)
        h = (1 - i) * n + i * h

        return h
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号