sru.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:benchmark 作者: pytorch 项目源码 文件源码
def forward(self, u, x, bias, init=None, mask_h=None):
    """Run the sequential SRU recurrence over a pre-computed projection ``u``.

    Args:
        u: projected input of shape ``(length, batch, d_out * k)`` where
            ``k`` is 3 or 4 (4 when a transformed copy of ``x`` for the
            highway connection is packed in as the 4th slice).
        x: input sequence, shape ``(length, batch, input_dim)`` or
            ``(batch, input_dim)`` for a single step.
        bias: 1-D tensor of size ``2 * d_out`` — forget-gate bias followed
            by reset/highway-gate bias.
        init: optional initial cell state ``(batch, d_out)``; zeros when None.
        mask_h: optional dropout mask multiplied into the activated cell
            state (same trailing shape as the cell state).

    Returns:
        Tuple ``(h, last_hidden)``: stacked outputs of shape
        ``(length, batch, d_out)`` and the final cell state ``(batch, d_out)``.

    Raises:
        NotImplementedError: if ``self.bidirectional`` is set — the
        bidirectional path is not implemented in this fallback.
    """
    bidir = 2 if self.bidirectional else 1
    length = x.size(0) if x.dim() == 3 else 1
    batch = x.size(-2)
    d = self.d_out
    k = u.size(-1) // d
    # Number of projected matrices per direction (3, or 4 with highway input).
    k_ = k // 2 if self.bidirectional else k

    u = u.view(length, batch, d, k_)

    # NOTE: the unused `size` tuple computed here in the original was removed.
    cur = x.new_zeros(batch, d) if init is None else init
    bias1, bias2 = bias.split(self.d_out)
    u_ = [u.select(-1, i) for i in range(k_)]
    # With k_ == 3 the highway connection reuses x directly (requires
    # input_dim == d_out); otherwise the 4th slice of u carries the
    # transformed input.
    x_ = x if k_ == 3 else u_[3]

    h = []
    for t in range(length):
        u0t, u1t, u2t = u_[0][t], u_[1][t], u_[2][t]
        g1 = torch.sigmoid(u1t + bias1)  # forget gate
        g2 = torch.sigmoid(u2t + bias2)  # reset / highway gate
        cur = (cur - u0t) * g1 + u0t
        if self.activation_type == 1:
            val = torch.tanh(cur)
        elif self.activation_type == 2:
            val = torch.relu(cur)
        else:
            # activation_type 0: identity. The original left `val` unbound
            # here, raising NameError for any type other than 1 or 2.
            val = cur
        if mask_h is not None:
            val = val * mask_h
        xt = x_[t]
        h.append((val - xt) * g2 + xt)

    if self.bidirectional:
        # Was `assert False`, which is stripped under `python -O` and would
        # then fall through with `last_hidden` unbound.
        raise NotImplementedError("bidirectional SRU is not supported here")
    last_hidden = cur
    h = torch.stack(h)
    return h, last_hidden
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号