grl_seq2seq.py source code

Python

Project: Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow  Author: liuyuemaicha

The snippet below is the project's variant of TensorFlow's legacy embedding_rnn_seq2seq: it embeds and encodes encoder_inputs with an RNN, then decodes from the final encoder state via embedding_rnn_decoder, threading beam_search and beam_size through to the decoder. When feed_previous is a Python bool the decoder is built once; when it is a Tensor, both graphs are built and one is selected at run time with control_flow_ops.cond.
# TF 0.x-era legacy imports this snippet relies on; embedding_rnn_decoder
# is expected to be defined earlier in this same file.
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, num_encoder_symbols,
                          num_decoder_symbols, embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          dtype=None,
                          scope=None,
                          beam_search=True,
                          beam_size=10):
  """A variant of TensorFlow's legacy embedding_rnn_seq2seq that additionally
  forwards beam_search and beam_size to embedding_rnn_decoder."""
  with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope:
    if dtype is not None:
      scope.set_dtype(dtype)
    else:
      dtype = scope.dtype

    # Encoder.
    encoder_cell = rnn_cell.EmbeddingWrapper(
        cell, embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    if output_projection is None:
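      # No output_projection was supplied, so wrap the cell to project each
      # decoder output to num_decoder_symbols logits.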
      cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)

    if isinstance(feed_previous, bool):
      return embedding_rnn_decoder(decoder_inputs, encoder_state, cell, num_decoder_symbols, embedding_size,
                                   output_projection=output_projection,
                                   feed_previous=feed_previous,
                                   scope=scope,
                                   beam_search=beam_search,
                                   beam_size=beam_size)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
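    # Both cond branches must construct identical variable sets: the first
    # decoder call creates the variables (reuse=None) and the second shares
    # them (reuse=True) instead of creating duplicates.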
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse) as scope:
        outputs, state = embedding_rnn_decoder(
            decoder_inputs, encoder_state, cell, num_decoder_symbols,
            embedding_size, output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False,
            beam_search=beam_search,
            beam_size=beam_size)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(structure=encoder_state,
                                    flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state
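
For orientation, here is a minimal sketch of how this wrapper might be driven. All names and sizes are hypothetical, it assumes the TF 0.x placeholder API this file targets, and with beam_search=True the precise return values depend on the project's modified embedding_rnn_decoder:

# Minimal usage sketch (hypothetical names and sizes; assumes TF 0.x
# placeholders and that embedding_rnn_seq2seq lives in this module).
import tensorflow as tf

batch_size, seq_len = 32, 10

# The legacy seq2seq API takes time-major lists of int32 id tensors.
encoder_inputs = [tf.placeholder(tf.int32, [batch_size], name="enc%d" % t)
                  for t in range(seq_len)]
decoder_inputs = [tf.placeholder(tf.int32, [batch_size], name="dec%d" % t)
                  for t in range(seq_len)]

cell = rnn_cell.GRUCell(256)
outputs, state = embedding_rnn_seq2seq(
    encoder_inputs, decoder_inputs, cell,
    num_encoder_symbols=10000,
    num_decoder_symbols=10000,
    embedding_size=128,
    feed_previous=True,   # decode from the model's own previous outputs
    beam_search=True,
    beam_size=10)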