def forward(self, x, ident, context, start=True):
    """Run the recurrent decoder over every step of ``x``.

    At each step the attention module reads the current buffer ``self.S_t``,
    the (transposed) ``context`` and the previous attention state
    ``self.mu_t``. The buffer is then updated, and a new output frame is
    predicted from the flattened buffer plus an ``ident``-dependent term.

    In training mode the input for step ``t`` is the ``t``-th slice of ``x``.
    Otherwise the model's own previous prediction is fed back in. The first
    step still uses ``x[0]``.

    Args:
        x: Input sequence, iterated one slice at a time along dim 0
            (presumably ``(time, batch, features)`` -- TODO confirm).
        ident: Passed to ``init_buffer``, ``update_buffer`` and ``F_o``.
        context: Attended-over tensor. Dims 0 and 1 are swapped before it is
            handed to ``self.attn``.
        start: Forwarded to ``self.init_buffer`` (presumably whether to reset
            the recurrent state -- confirm).

    Returns:
        Tuple ``(out_seq, attns_seq)``: the per-step outputs and the per-step
        attention weights, each stacked along a new leading dim 0.

    Side effects:
        Overwrites ``self.S_t`` and ``self.mu_t`` on every step.
    """
    out, attns = [], []
    o_t = x[0]
    self.init_buffer(ident, start)
    # The transposed context does not change across steps, so compute it once
    # instead of once per iteration.
    context_t = context.transpose(0, 1)
    for o_tm1 in torch.split(x, 1):
        if not self.training:
            # Feed back the previous prediction as this step's input.
            o_tm1 = o_t.unsqueeze(0)

        # predict weighted context based on S
        c_t, mu_t, alpha_t = self.attn(self.S_t, context_t, self.mu_t)

        # advance mu and update buffer
        self.S_t = self.update_buffer(self.S_t, c_t, o_tm1, ident)
        self.mu_t = mu_t

        # predict next time step based on buffer content
        ot_out = self.N_o(self.S_t.view(self.S_t.size(0), -1))
        # Recomputed per step on purpose: F_o may be stochastic (e.g. dropout).
        sp_out = self.F_o(ident)
        o_t = self.output(ot_out + sp_out)

        out.append(o_t)
        # NOTE(review): squeeze() drops *every* size-1 dim of alpha_t, which
        # may include the batch dim when batch size is 1 -- confirm intended.
        attns.append(alpha_t.squeeze())

    return torch.stack(out), torch.stack(attns)
评论列表
文章目录