goru.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:URNN-PyTorch 作者: jingli9111 项目源码 文件源码
def _EUNN(self, hx, thetaA, thetaB):
        """Apply the tunable-space EUNN stack of 2x2 rotations to ``hx``.

        The transform repeats ``self.capacity // 2`` times a pair of layers of
        independent 2x2 rotations:

        * layer A rotates the adjacent coordinate pairs ``(0, 1), (2, 3), ...``;
        * layer B rotates the offset pairs ``(1, 2), (3, 4), ...`` and leaves
          the first and last coordinates unchanged.

        Each 2x2 block maps ``(a, b)`` to ``(cos*a - sin*b, sin*a + cos*b)``,
        so the whole operation is orthogonal and preserves the row norms of
        ``hx``.

        Args:
            hx: 2-D tensor of shape ``(batch, self.hidden_size)``.
                ``hidden_size`` must be even, because the pair views below
                require it.
            thetaA: 2-D angle tensor for layer A with
                ``(capacity // 2) * (hidden_size // 2)`` elements, presumably
                shaped ``(capacity // 2, hidden_size // 2)``. ``None`` falls
                back to ``self.thetaA``.
            thetaB: 2-D angle tensor for layer B with
                ``(capacity // 2) * (hidden_size // 2 - 1)`` elements,
                presumably shaped ``(capacity // 2, hidden_size // 2 - 1)``.
                ``None`` falls back to ``self.thetaB``.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        # The original ignored these arguments and always read the module
        # attributes. Honour what the caller passes, and keep the attributes
        # as a fallback.
        if thetaA is None:
            thetaA = self.thetaA
        if thetaB is None:
            thetaB = self.thetaB

        L = self.capacity
        N = self.hidden_size
        n_layers = L // 2

        sinA = torch.sin(thetaA)
        cosA = torch.cos(thetaA)
        sinB = torch.sin(thetaB)
        cosB = torch.cos(thetaB)

        # Interleave each angle's cos/sin so that columns 2k and 2k+1 share
        # the angle of pair k:
        #   diag = [c0, c0, c1, c1, ...]
        #   off  = [-s0, s0, -s1, s1, ...]
        diagA = torch.stack((cosA, cosA), 2).view(n_layers, N)
        offA = torch.stack((-sinA, sinA), 2).view(n_layers, N)
        diagB = torch.stack((cosB, cosB), 2).view(n_layers, N - 2)
        offB = torch.stack((-sinB, sinB), 2).view(n_layers, N - 2)

        # Layer B leaves the first and last coordinates untouched: pad with
        # an identity diagonal (1) and a zero off-diagonal (0). The padding is
        # derived from the angle tensors so dtype and device always match.
        ones = torch.ones_like(diagB[:, :1])
        zeros = torch.zeros_like(offB[:, :1])
        diagB = torch.cat((ones, diagB, ones), 1)
        offB = torch.cat((zeros, offB, zeros), 1)

        batch_size = hx.size(0)
        x = hx
        for i in range(n_layers):
            # Layer A: swap every adjacent pair (0,1), (2,3), ..., then
            # combine as x*diag + swapped*off.
            y = x.contiguous().view(batch_size, N // 2, 2)
            y = torch.stack((y[:, :, 1], y[:, :, 0]), 2).view(batch_size, N)
            x = x * diagA[i].expand_as(x) + y * offA[i].expand_as(x)

            # Layer B: swap the pairs (1,2), (3,4), ... inside x[:, 1:-1].
            # Stack on dim 2, exactly as in layer A, so the swapped values
            # stay interleaved and line up with offB's [-s, s] layout.
            # Stacking on dim 1 would concatenate odd and even entries and
            # break the rotation.
            y = x[:, 1:-1].contiguous().view(batch_size, N // 2 - 1, 2)
            y = torch.stack((y[:, :, 1], y[:, :, 0]), 2).view(batch_size, N - 2)
            # The edge coordinates are multiplied by a zero off-diagonal, so
            # any value works here; reuse x's own edges.
            y = torch.cat((x[:, :1], y, x[:, -1:]), 1)
            x = x * diagB[i].expand_as(x) + y * offB[i].expand_as(x)
        return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号