def forward(self, input, hidden, context, init_output):
    """Run the decoder one time step at a time and collect the outputs.

    The input is embedded with ``self.word_lut`` and split along dim 0 into
    single-step slices. Each slice goes through ``self.rnn``, then
    ``self.attn`` (together with the transposed ``context``), then
    ``self.dropout``.

    Args:
        input: Passed to ``self.word_lut``; the result is iterated along
            dim 0 (presumably the time axis -- confirm against the caller).
        hidden: Initial recurrent state, threaded through ``self.rnn``.
        context: Tensor handed to ``self.attn`` with its first two
            dimensions swapped (presumably encoder states -- confirm).
        init_output: Output used for input feeding on the first step. Only
            read when ``self.input_feed`` is true.

    Returns:
        A tuple ``(outputs, hidden, attn)``: the per-step outputs stacked
        along a new leading dimension, the final recurrent state, and the
        second value returned by ``self.attn`` on the last step.
    """
    emb = self.word_lut(input)

    # The transposed context does not change between steps, so build it once.
    # ``Tensor.t()`` only accepts tensors with at most 2 dimensions; for
    # higher-rank contexts swap the first two axes explicitly. Rank <= 2
    # keeps using ``t()`` so existing behaviour is untouched.
    if context.dim() <= 2:
        attn_context = context.t()
    else:
        attn_context = context.transpose(0, 1)

    # n.b. you can increase performance if you compute W_ih * x for all
    # iterations in parallel, but that's only possible if
    # self.input_feed=False
    outputs = []
    output = init_output
    for emb_t in emb.split(1):
        # split(1) keeps a size-1 leading dim; drop it for the step input.
        emb_t = emb_t.squeeze(0)
        if self.input_feed:
            # Input feeding: append the previous step's output along dim 1.
            emb_t = torch.cat([emb_t, output], 1)

        output, hidden = self.rnn(emb_t, hidden)
        output, attn = self.attn(output, attn_context)
        output = self.dropout(output)
        outputs.append(output)

    return torch.stack(outputs), hidden, attn
# Comment list
# Table of contents