def readout_gru(self, target_prev_char_seq, target_prev_char_aux, input_states):
    """Run the target-side character GRU and project its output for the softmax.

    Args:
        target_prev_char_seq: previous target characters; embedded via
            ``self.lookup.apply``.
        target_prev_char_aux: forwarded to ``self.igru.apply`` as ``mask``.
        input_states: forwarded unchanged to ``self.igru.apply`` as
            ``input_states``.

    Returns:
        Whatever ``self.gru_to_softmax.apply`` produces for the GRU output.
    """
    char_embeddings = self.lookup.apply(target_prev_char_seq)

    # gru_fork.apply(..., as_dict=True) yields the keyword arguments for
    # igru.apply; mask and input_states are layered on top, so they win on
    # any key collision.
    gru_kwargs = dict(self.gru_fork.apply(char_embeddings, as_dict=True))
    gru_kwargs.update(mask=target_prev_char_aux, input_states=input_states)

    hidden = self.igru.apply(**gru_kwargs)

    # NOTE(review): when igru_depth > 1, igru.apply presumably returns one
    # result per layer; only the last entry is read out. Confirm against
    # the igru construction.
    if self.igru_depth > 1:
        hidden = hidden[-1]

    return self.gru_to_softmax.apply(hidden)
评论列表
文章目录