# Module-level dependencies assumed by this method (pre-0.4 PyTorch style):
import torch
from torch.autograd import Variable
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

def forward(self, e, input, mask, scale=0):
    # Random initial hidden states: (batch_size, n, hidden_size)
    hidden = Variable(torch.randn(self.batch_size, self.n,
                                  self.hidden_size)).type(dtype)
    if scale == 0:
        # Coarsest scale: no predictions from a previous scale, so zero out e
        e = Variable(torch.zeros(self.batch_size, self.n)).type(dtype)
    # Build the (batch_size, n, n) graph operator from e and the mask
    Phi = self.build_Phi(e, mask)
    # Per-node degree under Phi
    N = torch.sum(Phi, 2).squeeze()
    N += (N == 0).type(dtype)  # avoid division by zero for isolated nodes
    Nh = N.unsqueeze(2).expand(self.batch_size, self.n,
                               self.hidden_size + self.input_size)
    # Normalize inputs, important part!
    mask_inp = mask.unsqueeze(2).expand_as(input)
    input_n = self.Normalize_inputs(Phi, input) * mask_inp
    # input_n = input * mask_inp
    for layer in self.layers:
        hidden = layer(input_n, hidden, Phi, Nh)
    # Score every node with a shared linear layer
    hidden_p = hidden.view(self.batch_size * self.n, self.hidden_size)
    scores = self.linear_b(hidden_p)
    # probs has shape (batch_size, n); padding nodes are masked out
    probs = torch.sigmoid(scores).view(self.batch_size, self.n) * mask
    return scores, probs, input_n, Phi
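
To make the tensor shapes concrete, here is a minimal calling sketch. Only the `forward` signature and the returned shapes come from the code above; the class name `GNN`, its constructor arguments, and the hypothetical sizes are assumptions for illustration. Note that `Variable` is the pre-0.4 PyTorch API; on newer versions plain tensors work the same way here.

# Hypothetical setup; GNN stands in for the class that owns forward()
batch_size, n, input_size = 32, 50, 8
model = GNN(batch_size=batch_size, n=n,
            input_size=input_size, hidden_size=16)  # assumed constructor

e = Variable(torch.zeros(batch_size, n)).type(dtype)            # previous-scale predictions
input = Variable(torch.randn(batch_size, n, input_size)).type(dtype)
mask = Variable(torch.ones(batch_size, n)).type(dtype)          # 1 for real nodes, 0 for padding

scores, probs, input_n, Phi = model(e, input, mask, scale=0)
print(probs.size())  # (batch_size, n)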