def step(self, time_, inputs, state, name=None):
    # Run the wrapped RNN cell for one time step.
    cell_output, cell_state = self.cell(inputs, state)
    # Attend over the encoder outputs and project to vocabulary logits.
    cell_output_new, logits, attention_scores, attention_context = \
        self.compute_output(cell_output)

    # If the source sequence was fed to the encoder in reverse, flip the
    # attention scores back so they line up with the original word order.
    if self.reverse_scores_lengths is not None:
        attention_scores = tf.reverse_sequence(
            input=attention_scores,
            seq_lengths=self.reverse_scores_lengths,
            seq_dim=1,
            batch_dim=0)

    # Let the helper choose the next token ids (e.g. greedy sampling during
    # inference, or ground-truth tokens during training).
    sample_ids = self.helper.sample(
        time=time_, outputs=logits, state=cell_state)

    outputs = AttentionDecoderOutput(
        logits=logits,
        predicted_ids=sample_ids,
        cell_output=cell_output_new,
        attention_scores=attention_scores,
        attention_context=attention_context)

    # The helper also decides what to feed in at the next step and whether
    # decoding has finished for each batch element.
    finished, next_inputs, next_state = self.helper.next_inputs(
        time=time_, outputs=outputs, state=cell_state, sample_ids=sample_ids)

    return (outputs, next_state, next_inputs, finished)
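
To make the (outputs, next_state, next_inputs, finished) contract concrete, here is a minimal, framework-free sketch of the loop that drives a step() method like the one above. The names ToyDecoder and dynamic_decode are hypothetical, and tensors are replaced by plain Python values; this only illustrates how the helper-driven step loop threads state and inputs from one time step to the next, not the actual library implementation.

from collections import namedtuple

AttentionDecoderOutput = namedtuple(
    "AttentionDecoderOutput",
    ["logits", "predicted_ids", "cell_output",
     "attention_scores", "attention_context"])


class ToyDecoder:
    """Hypothetical decoder whose step() mirrors the signature above."""

    def __init__(self, max_steps=3):
        self.max_steps = max_steps

    def step(self, time_, inputs, state, name=None):
        # Stand-ins for cell(), compute_output() and helper.sample():
        cell_output = inputs + state          # "RNN cell"
        logits = [cell_output]                # "projection to vocab"
        sample_id = int(cell_output) % 10     # "greedy sampling"
        outputs = AttentionDecoderOutput(
            logits=logits,
            predicted_ids=sample_id,
            cell_output=cell_output,
            attention_scores=None,
            attention_context=None)
        finished = time_ + 1 >= self.max_steps
        next_inputs = float(sample_id)        # feed the prediction back in
        next_state = cell_output
        return outputs, next_state, next_inputs, finished


def dynamic_decode(decoder, initial_inputs=1.0, initial_state=0.0):
    """Repeatedly call step() until it reports finished."""
    time_, inputs, state = 0, initial_inputs, initial_state
    collected = []
    while True:
        outputs, state, inputs, finished = decoder.step(time_, inputs, state)
        collected.append(outputs.predicted_ids)
        time_ += 1
        if finished:
            return collected


print(dynamic_decode(ToyDecoder()))  # e.g. [1, 2, 4]

In the real library this loop lives in the dynamic decoding routine; the decoder never advances time itself, it only reports what the next inputs and state should be.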