def Decoder(self, input, hidden_encoder, phis,
            input_target=None, target=None):
    """Pointer-network-style decoder unrolled for `self.n` steps.

    At each step the decoder attends over the encoder states, masked by
    the scope matrix `phis`; scope boundaries (zero subdiagonal entries
    of Phi) reset the hidden state and input to their initial values.

    Args:
        input: encoder inputs, presumably (batch_size, n, input_size) —
            verify against caller.
        hidden_encoder: encoder hidden states,
            (batch_size, n, hidden_size).
        phis: (batch_size, n, n) scope mask; row `phis[:, n]` masks the
            softmax at step n, and `phis.sum(2)` gives per-element scope
            sizes.
        input_target: optional inputs indexed by `target` for teacher
            forcing; required when `target` is given.
        target: optional (batch_size, n) ground-truth indices; when
            provided the decoder is teacher-forced.

    Returns:
        output: (batch_size, n, n) Variable of per-step attention
        distributions.
    """
    feed_target = target is not None
    # N[:, k] is the size of the scope containing the k-th element,
    # broadcast across the hidden dimension for gather indexing.
    N = phis.sum(2).squeeze().unsqueeze(2).expand(self.batch_size, self.n,
                                                  self.hidden_size)
    output = (Variable(torch.ones(self.batch_size, self.n, self.n))
              .type(dtype))
    # Initial hidden state: encoder state at the last element of the
    # first scope.
    index = ((N[:, 0] - 1) % (self.n)).type(dtype_l).unsqueeze(1)
    hidden = (torch.gather(hidden_encoder, 1, index)).squeeze()
    # W1xe size: (batch_size, n + 1, hidden_size); precomputed once so
    # the attention only has to project the decoder state per step.
    W1xe = (torch.bmm(hidden_encoder, self.W1.unsqueeze(0).expand(
        self.batch_size, self.hidden_size, self.hidden_size)))
    # init token
    start = (self.init_token.unsqueeze(0).expand(self.batch_size,
                                                 self.input_size))
    input_step = start
    for n in xrange(self.n):
        # Decouple interaction between different scopes by looking at
        # subdiagonal elements of Phi: gate == 0 marks a scope boundary,
        # where hidden state and input are reset to fresh initial values.
        if n > 0:
            # Hoisted: the same subdiagonal gate is used for both the
            # hidden-state and input resets.
            gate = phis[:, n, n - 1].squeeze().unsqueeze(1)
            t = gate.expand(self.batch_size, self.hidden_size)
            index = (((N[:, n] + n - 1) % (self.n)).type(dtype_l)
                     .unsqueeze(1))
            init_hidden = (torch.gather(hidden_encoder, 1, index)
                           .squeeze())
            hidden = t * hidden + (1 - t) * init_hidden
            t = gate.expand(self.batch_size, self.input_size)
            input_step = t * input_step + (1 - t) * start
        # Compute next state
        hidden = self.decoder_cell(input_step, hidden)
        # Compute pairwise interactions
        u = self.attention(hidden, W1xe, hidden_encoder, tanh=True)
        # Normalize interactions by taking the masked softmax by phi
        attn = self.softmax_m(phis[:, n].squeeze(), u)
        if feed_target:
            # Teacher forcing: feed the ground-truth element at this
            # step. (Renamed from `next`, which shadowed the builtin.)
            target_idx = (target[:, n].unsqueeze(1).unsqueeze(2)
                          .expand(self.batch_size, 1, self.input_size)
                          .type(dtype_l))
            input_step = torch.gather(input_target, 1,
                                      target_idx).squeeze()
        else:
            # Blend inputs weighted by the attention distribution.
            input_step = (torch.sum(attn.unsqueeze(2).expand(
                self.batch_size, self.n,
                self.input_size) * input, 1)).squeeze()
        # Update output
        output[:, n] = attn
    return output