# Assumes PyTorch with `from torch.autograd import Variable`, where `dtype`
# and `dtype_l` are the float and long tensor types (CPU or CUDA variants).
def Decoder(self, input, hidden_encoder, phis,
            input_target=None, target=None):
    # Teacher forcing: feed ground-truth targets when they are provided.
    feed_target = target is not None
    # N[:, n] is the number of elements in the scope of the n-th element
    N = phis.sum(2).squeeze().unsqueeze(2).expand(self.batch_size, self.n,
                                                  self.hidden_size)
    output = (Variable(torch.ones(self.batch_size, self.n, self.n + 1))
              .type(dtype))
    # Initial decoder state: encoder state at the end of the first scope
    index = ((N[:, 0] - 1) % self.n).type(dtype_l).unsqueeze(1).detach()
    hidden = torch.gather(hidden_encoder, 1, index + 1).squeeze()
    # W1xe size: (batch_size, n + 1, hidden_size)
    W1xe = torch.bmm(hidden_encoder, self.W1.unsqueeze(0).expand(
        self.batch_size, self.hidden_size, self.hidden_size))
    # init token
    start = self.init_token.unsqueeze(0).expand(self.batch_size,
                                                self.input_size)
    input_step = start
    for n in xrange(self.n):
        # Decouple interactions between different scopes by looking at the
        # subdiagonal elements of Phi: phis[:, n, n - 1] is ~1 inside a
        # scope and ~0 at a scope boundary.
        if n > 0:
            t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                self.batch_size, self.hidden_size))
            index = (((N[:, n] + n - 1) % self.n).type(dtype_l)
                     .unsqueeze(1)).detach()
            init_hidden = (torch.gather(hidden_encoder, 1, index + 1)
                           .squeeze())
            hidden = t * hidden + (1 - t) * init_hidden
            t = (phis[:, n, n - 1].squeeze().unsqueeze(1).expand(
                self.batch_size, self.input_size))
            input_step = t * input_step + (1 - t) * start
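            # Net effect: inside a scope (t = 1) the state carries over; at
            # a scope boundary (t = 0) the decoder restarts from a fresh
            # encoder state and the init token, so scopes do not interact.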
        # Compute next state
        hidden = self.decoder_cell(input_step, hidden)
        # Compute pairwise interactions
        u = self.attention(hidden, W1xe, hidden_encoder)
        # Normalize interactions with a softmax masked by phi
        pad = Variable(torch.ones(self.batch_size, 1)).type(dtype)
        mask = torch.cat((pad, phis[:, n].squeeze()), 1)
        attn = self.softmax_m(mask, u)
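        # hidden_encoder has n + 1 steps: slot 0 precedes the n elements
        # (note the `index + 1` gathers above). The ones-pad keeps slot 0
        # always selectable, while phis[:, n] restricts the remaining n
        # slots to the current element's scope.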
        if feed_target:
            # Teacher forcing: feed the ground-truth target as next input
            next_idx = (target[:, n].unsqueeze(1).unsqueeze(2)
                        .expand(self.batch_size, 1, self.input_size)
                        .type(dtype_l))
            input_step = torch.gather(input_target, 1, next_idx).squeeze()
        else:
            # Greedy decoding: feed the argmax of the attention, no blending
            index = attn.max(1)[1].squeeze()
            next_idx = (index.unsqueeze(1).unsqueeze(2)
                        .expand(self.batch_size, 1, self.input_size)
                        .type(dtype_l))
            input_step = torch.gather(input, 1, next_idx).squeeze()
            # blend inputs
            # input_step = (torch.sum(attn.unsqueeze(2).expand(
            #     self.batch_size, self.n + 1,
            #     self.input_size) * input, 1)).squeeze()
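            # The commented-out alternative would instead feed a soft mixture
            # of all inputs weighted by attn, which keeps this step
            # differentiable with respect to the attention weights.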
        # Update output
        output[:, n] = attn
    return output
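
# The helpers `attention` and `softmax_m` used above are not part of this
# listing. Below is a minimal sketch of what they plausibly compute, assuming
# the standard pointer-network scoring u_i = v^T tanh(W1 e_i + W2 d); the
# parameter names `self.W2` and `self.v` are assumptions, not taken from the
# original code, and the sketch uses a recent PyTorch API for brevity.

def attention(self, hidden, W1xe, hidden_encoder):
    # hidden: (batch_size, hidden_size), the decoder state d
    # W1xe:   (batch_size, n + 1, hidden_size), precomputed W1 * e_i
    W2d = torch.mm(hidden, self.W2).unsqueeze(1)       # (batch, 1, hidden)
    scores = torch.tanh(W1xe + W2d.expand_as(W1xe))    # (batch, n + 1, hidden)
    v = self.v.unsqueeze(0).unsqueeze(2).expand(
        self.batch_size, self.hidden_size, 1)
    return torch.bmm(scores, v).squeeze(2)             # u: (batch, n + 1)

def softmax_m(self, mask, u):
    # Softmax over u restricted to positions where mask == 1; masked-out
    # positions get probability exactly 0.
    u = u - u.max(1, keepdim=True)[0]                  # numerical stability
    e = torch.exp(u) * mask
    return e / e.sum(1, keepdim=True).clamp(min=1e-10)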