def update_buffer(self, S_tm1, c_t, o_tm1, ident):
    """Shift a new column into the front of the buffer ``S_tm1``.

    A candidate column ``z_t`` is built from the context ``c_t`` (shifted by
    ``tanh(self.F_u(ident))``) and the previous output ``o_tm1`` (scaled by
    1/30). ``z_t`` is prepended to ``S_tm1`` with its last column dropped;
    that temporary buffer is flattened per batch row and passed through
    ``self.N_u`` to produce the new column ``u``. The identity term is added
    to the leading ``idt.size(1)`` entries of ``u``. Finally ``u`` is
    prepended to ``S_tm1`` with its last column dropped.

    Args:
        S_tm1: Previous buffer, a 3-D tensor. New columns enter at index 0
            of the last dim and the last column is discarded. Presumably
            (batch, features, slots) -- TODO confirm against caller.
        c_t: 2-D context tensor; must have the same shape as
            ``self.F_u(ident)`` since the two are added.
        o_tm1: Previous output; its leading dim is removed with
            ``squeeze(0)`` and the result is concatenated with ``c_t``
            along dim 1.
        ident: Input to ``self.F_u``.

    Returns:
        The updated buffer, with the same shape as ``S_tm1`` (the
        concatenations require ``self.N_u`` to output ``S_tm1.size(1)``
        features).
    """
    # concat previous output & context
    idt = torch.tanh(self.F_u(ident))
    k = idt.size(1)
    o_tm1 = o_tm1.squeeze(0)
    # NOTE(review): fixed 1/30 scaling of the previous output; the origin of
    # this constant is not visible here.
    z_t = torch.cat([c_t + idt, o_tm1 / 30], 1).unsqueeze(2)
    # Temporary buffer used only to compute u: z_t prepended, last column dropped.
    Sp = torch.cat([z_t, S_tm1[:, :, :-1]], 2)

    # update S
    u = self.N_u(Sp.view(Sp.size(0), -1))
    # Add the identity term out of place. Writing into a slice of N_u's
    # output in place would invalidate that tensor for autograd whenever
    # N_u's last op saved its result for backward (e.g. Tanh, Sigmoid, ReLU).
    u = torch.cat([u[:, :k] + idt, u[:, k:]], 1)
    S = torch.cat([u.unsqueeze(2), S_tm1[:, :, :-1]], 2)
    return S
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")