def step(self, time, inputs, state: BeamSearchOptimizationDecoderState, name=None):
    """Perform a decoding step.

    The step feeds `inputs` and the cell state carried in `state` through
    `self._cell` (and `self._output_layer`, when one is configured), hands
    the result to `self._beam_search_step`, and derives the next inputs by
    applying `self._embedding_fn` to the predicted ids.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays. Its
        `cell_state` attribute is read here; the whole object is passed on
        to `self._beam_search_step` as `beam_state`.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`, where `outputs` and
      `next_state` are the two values returned by `self._beam_search_step`,
      `next_inputs` is `self._embedding_fn(outputs.predicted_ids)` and
      `finished` is `next_state.finished`.
    """
    with tf.name_scope(name, "BeamSearchOptimizationDecoderStep", (time, inputs, state)):
        cell_state = state.cell_state

        # Reshape inputs and cell state before invoking the cell.
        # NOTE(review): presumably this folds the batch and beam dimensions
        # into one, keeping dims from index 2 onward (`s=x.shape[2:]`) -
        # confirm against `_merge_batch_beams`.
        with tf.name_scope('merge_cell_input'):
            inputs = nest.map_structure(
                lambda x: self._merge_batch_beams(x, s=x.shape[2:]), inputs)
        with tf.name_scope('merge_cell_state'):
            cell_state = nest.map_structure(
                self._maybe_merge_batch_beams, cell_state, self._cell.state_size)

        cell_outputs, next_cell_state = self._cell(inputs, cell_state)

        # The output layer is optional; when present it is applied before
        # the outputs are split back.
        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)

        # Inverse of the reshaping above, applied to the cell results.
        with tf.name_scope('split_cell_outputs'):
            cell_outputs = nest.map_structure(
                self._split_batch_beams, cell_outputs, self._output_size)
        with tf.name_scope('split_cell_state'):
            next_cell_state = nest.map_structure(
                self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

        # The (possibly projected) cell outputs are passed as logits.
        beam_search_output, beam_search_state = self._beam_search_step(
            time=time,
            logits=cell_outputs,
            next_cell_state=next_cell_state,
            beam_state=state)

        finished = beam_search_state.finished
        sample_ids = beam_search_output.predicted_ids
        next_inputs = self._embedding_fn(sample_ids)

    return (beam_search_output, beam_search_state, next_inputs, finished)
beam_aligner.py 文件源码
python
阅读 41
收藏 0
点赞 0
评论 0
评论列表
文章目录