def decode_once(self, x, state, train=True):
    """Advance the decoder by one step and return its output and state.

    Args:
        x: Input for this step; it is passed to ``self.trg_emb``.
        state: Mapping that must contain ``'c'`` and ``'h'`` and may
            contain ``'h_tilde'`` from an earlier step. It is modified
            in place.
        train: Forwarded unchanged to ``self.attender`` as ``train``.

    Returns:
        A pair ``(o, state)`` where ``o`` is ``self.ho`` applied to the
        newly computed ``h_tilde`` and ``state`` is the same object that
        was passed in, with ``'c'``, ``'h'`` and ``'h_tilde'`` overwritten.
    """
    prev_cell = state['c']
    prev_hidden = state['h']
    prev_h_tilde = state.get('h_tilde')

    embedded = self.trg_emb(x)
    gate_input = self.eh(embedded) + self.hh(prev_hidden)
    # The previous h_tilde only contributes when one has been stored;
    # on a state without it the LSTM input is embedding + hidden only.
    if prev_h_tilde is not None:
        gate_input += self.ch(prev_h_tilde)

    new_cell, new_hidden = F.lstm(prev_cell, gate_input)
    attended = self.attender(new_hidden, train=train)
    # Join the attender result with the new hidden value, project it
    # through w_c and squash with tanh to obtain the new h_tilde.
    new_h_tilde = F.tanh(self.w_c(F.concat([attended, new_hidden])))
    output = self.ho(new_h_tilde)

    # Write the results back into the caller's state.
    state['c'] = new_cell
    state['h'] = new_hidden
    state['h_tilde'] = new_h_tilde
    return output, state
评论列表
文章目录