def processDecLSTMOneStep(self, decInputEmb, lstm_states_in,
                          finalHS, args, dropout_rate):
    """Run the decoder LSTM for a single time step from explicit states.

    The LSTM states are passed in and returned explicitly rather than kept
    implicitly between calls. This lets a caller run several independent
    hypotheses through the same decoder; beam search was presumably the
    motivation.

    Args:
        decInputEmb: Decoder input embedding for this step.
        lstm_states_in: States to load into ``self.model.decLSTM`` before
            the step, in the format ``getAllLSTMStates`` returns.
        finalHS: Final hidden state from the previous step. It is fed back
            into the input only when input feeding is enabled.
        args: Passed through unchanged to
            ``decLSTM.processOneStepForward``.
        dropout_rate: Passed through unchanged to
            ``decLSTM.processOneStepForward``.

    Returns:
        A tuple ``(h1, lstm_states_out)``, where:
            h1: Output of ``processOneStepForward``.
            lstm_states_out: Decoder LSTM states after the step.

    Raises:
        AssertionError: If ``self.flag_dec_ifeed`` is neither 0 nor 1.
            This is raised explicitly rather than with an ``assert``
            statement, so the check still runs under ``python -O``.
    """
    # 1. Load the caller-supplied states into the decoder LSTM.
    self.model.decLSTM.setAllLSTMStates(lstm_states_in)

    # 2. Build the input for this step, depending on the input-feed setting.
    if self.flag_dec_ifeed == 0:
        # No input feeding: use the embedding alone.
        wenbed = decInputEmb
    elif self.flag_dec_ifeed == 1:
        # Input feeding (marked as the default in the original code):
        # prepend the previous final hidden state to the embedding.
        # Concatenates along chaFunc.concat's default axis (axis=1 in
        # Chainer), presumably the feature axis -- TODO confirm.
        wenbed = chaFunc.concat((finalHS, decInputEmb))
    else:
        raise AssertionError(
            "ERROR: unsupported flag_dec_ifeed=%r (expected 0 or 1)"
            % (self.flag_dec_ifeed,))

    # 3. Advance the (possibly multi-layer) decoder LSTM by one step.
    h1 = self.model.decLSTM.processOneStepForward(
        wenbed, args, dropout_rate)

    # 4. Read back the updated states so the caller can store them.
    lstm_states_out = self.model.decLSTM.getAllLSTMStates()
    return h1, lstm_states_out
# Attention (the rest of the original comment was lost to character-encoding damage)
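# NOTE: the next two lines are residue from the web page this code was copied from, not code:
# 评论列表 ("comment list")
# 文章目录 ("table of contents")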