def _EUNN(self, hx, thetaA, thetaB):
    """Apply one EUNN (Efficient Unitary Neural Network) unitary transform.

    Multiplies ``hx`` by ``capacity // 2`` pairs of Givens-rotation layers:
    an "A" layer rotating element pairs (0,1), (2,3), ... and a "B" layer
    rotating the shifted pairs (1,2), (3,4), ... while the first and last
    elements pass through unchanged (Jing et al., "Tunable Efficient
    Unitary Neural Networks (EUNN) and their application to RNNs").

    Args:
        hx: hidden state, shape ``(batch, hidden_size)``.
        thetaA: A-layer rotation angles, shape
            ``(capacity // 2, hidden_size // 2)``.
        thetaB: B-layer rotation angles, shape
            ``(capacity // 2, hidden_size // 2 - 1)``.

    Returns:
        The transformed hidden state, shape ``(batch, hidden_size)``.
    """
    L = self.capacity
    N = self.hidden_size
    # FIX: the original read self.thetaA / self.thetaB, silently ignoring
    # the thetaA / thetaB arguments; use the parameters that were passed.
    sinA = torch.sin(thetaA)
    cosA = torch.cos(thetaA)
    sinB = torch.sin(thetaB)
    cosB = torch.cos(thetaB)
    # Fixed 1/0 borders for the B layers (endpoints are left untouched).
    # Allocated with the angles' dtype/device instead of the deprecated
    # Variable(...) CPU-only construction.
    I = torch.ones((L // 2, 1), dtype=thetaA.dtype, device=thetaA.device)
    O = torch.zeros((L // 2, 1), dtype=thetaA.dtype, device=thetaA.device)
    # Interleave (cos, cos) and (-sin, sin) so each adjacent element pair
    # shares one rotation angle: diag = [c0,c0,c1,c1,...], off = [-s0,s0,...].
    diagA = torch.stack((cosA, cosA), 2).view(L // 2, N)
    offA = torch.stack((-sinA, sinA), 2).view(L // 2, N)
    diagB = torch.stack((cosB, cosB), 2).view(L // 2, N - 2)
    offB = torch.stack((-sinB, sinB), 2).view(L // 2, N - 2)
    diagB = torch.cat((I, diagB, I), 1)
    offB = torch.cat((O, offB, O), 1)
    batch_size = hx.size()[0]
    x = hx
    for i in range(L // 2):
        # A layer: rotate pairs (0,1), (2,3), ...
        y = x.view(batch_size, N // 2, 2)
        y = torch.stack((y[:, :, 1], y[:, :, 0]), 2)  # swap within each pair
        y = y.view(batch_size, N)
        x = torch.mul(x, diagA[i].expand_as(x))
        y = torch.mul(y, offA[i].expand_as(y))
        x = x + y
        # B layer: rotate shifted pairs (1,2), (3,4), ...; endpoints fixed.
        x_top = x[:, 0]
        x_mid = x[:, 1:-1].contiguous()
        x_bot = x[:, -1]
        y = x_mid.view(batch_size, N // 2 - 1, 2)
        # FIX: swap within each pair must stack along dim 2 (as in the A
        # layer). The original stacked along dim 1, which for N >= 6 yields
        # [c,e,b,d] instead of [c,b,e,d] — a wrong, non-unitary permutation.
        y = torch.stack((y[:, :, 1], y[:, :, 0]), 2)
        y = y.view(batch_size, N - 2)
        x_top = torch.unsqueeze(x_top, 1)
        x_bot = torch.unsqueeze(x_bot, 1)
        y = torch.cat((x_top, y, x_bot), 1)
        x = x * diagB[i].expand(batch_size, N)
        y = y * offB[i].expand(batch_size, N)
        x = x + y
    return x
# (removed web-scrape page residue: "comment list" / "article table of contents")