def forward(self, input_n, hidden, phi, nh):
    """Compute updated hidden states from each row's own and aggregated features.

    The input features are appended to the hidden state, an aggregated
    copy is formed as ``bmm(phi, features) / nh``, and the result is
    ``sigmoid(features @ W1 + aggregated @ W2 + b)``.

    The batch size is taken from the tensors themselves rather than from
    ``self.batch_size``, so batches of any size (for example a final
    partial batch) are handled.

    Args:
        input_n: 3-D tensor concatenated to ``hidden`` along dim 2; its
            last dimension is expected to be ``self.input_size``.
        hidden: 3-D tensor whose last dimension is expected to be
            ``self.hidden_size``.
        phi: 3-D tensor left-multiplied onto the concatenated features
            with ``torch.bmm`` (presumably a per-sample adjacency or
            mixing matrix; confirm against the caller).
        nh: divisor applied to the aggregated features; a scalar or any
            tensor broadcastable against them (presumably a neighbour
            count; confirm against the caller).

    Returns:
        Tensor with the same leading dimensions as ``hidden`` and a last
        dimension equal to the number of columns of ``self.W1``, with
        values in (0, 1).
    """
    # Append the input features to every hidden vector (feature axis = 2).
    features = torch.cat((hidden, input_n), 2)

    # Aggregate representations: mix rows through phi, then rescale by nh.
    aggregated = torch.div(torch.bmm(phi, features), nh)

    # matmul broadcasts the 2-D weight matrices over the batch dimension,
    # so no flattening to 2-D (and no hard-coded batch size) is needed.
    own_term = torch.matmul(features, self.W1)
    aggregated_term = torch.matmul(aggregated, self.W2)

    # The bias broadcasts over the leading dimensions.
    return torch.sigmoid(own_term + aggregated_term + self.b)
评论列表
文章目录