seq2seq_ops.py source code (Python)

Project: lsdc, Author: febert
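from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope as vs


def rnn_decoder(decoder_inputs, initial_state, cell, scope=None):
  """RNN decoder that creates training and sampling sub-graphs.

  Args:
    decoder_inputs: Inputs for the decoder, a list of tensors.
      These are used only in the training sub-graph; the first one also
      drives the first sampling step.
    initial_state: Initial state for the decoder.
    cell: RNN cell to use for the decoder. Sampling feeds each output back
      as the next input, so the cell's output size must match its input size.
    scope: Variable scope to use; if None, a new "dnn_decoder" scope is
      created.

  Returns:
    A tuple (outputs, states, sampling_outputs, sampling_states) of lists of
    tensors for the training and sampling sub-graphs. Each states list starts
    with initial_state, so it is one element longer than its outputs list.
  """
  with vs.variable_scope(scope or "dnn_decoder"):
    states, sampling_states = [initial_state], [initial_state]
    outputs, sampling_outputs = [], []
    # Training sub-graph: feed the ground-truth input at every step.
    with ops.name_scope("training", values=[decoder_inputs, initial_state]):
      for i, inp in enumerate(decoder_inputs):
        if i > 0:
          # Share the cell's weights across time steps.
          vs.get_variable_scope().reuse_variables()
        output, new_state = cell(inp, states[-1])
        outputs.append(output)
        states.append(new_state)
    # Sampling sub-graph: feed the previous output back in as the next input.
    with ops.name_scope("sampling", values=[initial_state]):
      for i, _ in enumerate(decoder_inputs):
        if i == 0:
          # The first step is identical to the training step, so reuse its
          # output and the state it produced (states[1], not states[0]).
          sampling_outputs.append(outputs[i])
          sampling_states.append(states[i + 1])
        else:
          sampling_output, sampling_state = cell(sampling_outputs[-1],
                                                 sampling_states[-1])
          sampling_outputs.append(sampling_output)
          sampling_states.append(sampling_state)
  return outputs, states, sampling_outputs, sampling_states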
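To make the two sub-graphs concrete, here is a framework-free sketch of the same control flow. It is not the lsdc or TensorFlow implementation: it assumes a hypothetical scalar `toy_cell` in place of a TensorFlow RNN cell and a hypothetical `rnn_decoder_sketch` in place of `rnn_decoder`. The training pass consumes every ground-truth input (teacher forcing), while the sampling pass reuses the first training step and then feeds each predicted output back in as the next input. Here `sampling_states[1]` is set to `states[1]`, the state actually produced by that shared first step.

```python
from typing import Callable, List, Sequence, Tuple

# A cell maps (input, state) -> (output, new_state), like an RNN cell.
Cell = Callable[[float, float], Tuple[float, float]]


def toy_cell(inp: float, state: float) -> Tuple[float, float]:
    """Hypothetical scalar cell, used only for illustration."""
    new_state = 0.5 * state + inp
    output = new_state  # same "size" as the input, so it can be fed back
    return output, new_state


def rnn_decoder_sketch(
    decoder_inputs: Sequence[float], initial_state: float, cell: Cell
) -> Tuple[List[float], List[float], List[float], List[float]]:
    # Training: every step consumes the ground-truth input (teacher forcing).
    states: List[float] = [initial_state]
    outputs: List[float] = []
    for inp in decoder_inputs:
        output, new_state = cell(inp, states[-1])
        outputs.append(output)
        states.append(new_state)

    # Sampling: step 0 is shared with training; later steps feed back the
    # previous predicted output instead of the ground-truth input.
    sampling_states: List[float] = [initial_state]
    sampling_outputs: List[float] = []
    for i in range(len(decoder_inputs)):
        if i == 0:
            sampling_outputs.append(outputs[0])
            sampling_states.append(states[1])
        else:
            output, new_state = cell(sampling_outputs[-1], sampling_states[-1])
            sampling_outputs.append(output)
            sampling_states.append(new_state)
    return outputs, states, sampling_outputs, sampling_states


outs, sts, s_outs, s_sts = rnn_decoder_sketch([1.0, 2.0, 3.0], 0.0, toy_cell)
print(outs)    # [1.0, 2.5, 4.25]  (training outputs)
print(s_outs)  # [1.0, 1.5, 2.25]  (sampling outputs)
```

The two output lists agree only at step 0; afterwards the sampling pass drifts because it sees its own predictions rather than the targets. In the TensorFlow version, weight sharing between the passes comes from `reuse_variables()`; in this sketch it is implicit because the same Python function is called in both loops.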