def step(self, time, inputs, state, name=None):
    """Run one decoding step whose logits are constrained by a grammar.

    Args:
        time: Current decoding step. It is forwarded to ``self._helper`` and,
            when ``self._fixed_outputs`` is set, used as the read index into it.
        inputs: Inputs fed to ``self._cell`` at this step.
        state: Pair ``(decoder_state, grammar_state)``. The first element is
            passed to ``self._cell``; the second is consumed by
            ``self._grammar_helper``.
        name: Optional name for the enclosing ``tf.name_scope``; the default
            name "GrammarDecodingStep" is supplied to it.

    Returns:
        Tuple ``(outputs, next_state, next_inputs, finished)``:
        ``outputs`` is ``BasicDecoderOutput(constrained_logits, sample_ids)``;
        ``next_state`` is ``(next_decoder_state, next_grammar_state)``;
        ``next_inputs`` and ``finished`` come from ``self._helper.next_inputs``.
    """
    with tf.name_scope(name, "GrammarDecodingStep", (time, inputs, state)):
        prev_cell_state, grammar_state = state

        # Cell forward pass, optionally followed by the projection layer.
        raw_outputs, cell_state = self._cell(inputs, prev_cell_state)
        if self._output_layer is not None:
            raw_outputs = self._output_layer(raw_outputs)

        # The grammar-constrained values are used for sampling and for
        # next_inputs, and they are the outputs the step emits.
        logits = self._grammar_helper.constrain_logits(raw_outputs, grammar_state)

        sample_ids = self._helper.sample(time=time, outputs=logits, state=cell_state)
        finished, next_inputs, next_decoder_state = self._helper.next_inputs(
            time=time, outputs=logits, state=cell_state, sample_ids=sample_ids)

        # The grammar advances on ``self._fixed_outputs`` when it is provided
        # (presumably teacher forcing during training -- TODO confirm);
        # otherwise it advances on the ids that were just sampled.
        if self._fixed_outputs is None:
            transition_ids = sample_ids
        else:
            transition_ids = self._fixed_outputs.read(time)
        next_grammar_state = self._grammar_helper.transition(
            grammar_state, transition_ids, self.batch_size)

    step_outputs = BasicDecoderOutput(logits, sample_ids)
    return (step_outputs,
            (next_decoder_state, next_grammar_state),
            next_inputs,
            finished)
grammar_decoder.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录