seq2seq.py source code

python

Project: PTTChatBot_DL2017  Author: thisray
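# Module-level imports this function relies on (assumed TF 1.0/1.1-era
# layout, matching the core_rnn / core_rnn_cell names used below);
# rnn_decoder is defined elsewhere in this seq2seq.py module.
from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope


def tied_rnn_seq2seq(encoder_inputs,
                     decoder_inputs,
                     cell,
                     loop_function=None,
                     dtype=dtypes.float32,
                     scope=None):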
    """RNN sequence-to-sequence model with tied encoder and decoder parameters.

    This model first runs an RNN to encode encoder_inputs into a state vector,
    then runs the decoder, initialized with the last encoder state, on
    decoder_inputs. Encoder and decoder use the same RNN cell and share
    parameters.

    Args:
        encoder_inputs: A list of 2D Tensors [batch_size x input_size].
        decoder_inputs: A list of 2D Tensors [batch_size x input_size].
        cell: core_rnn_cell.RNNCell defining the cell function and size.
        loop_function: If not None, this function will be applied to the i-th
            output to generate the (i+1)-th input, and decoder_inputs will be
            ignored, except for the first element ("GO" symbol); see rnn_decoder
            for details.
        dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
        scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".

    Returns:
        A tuple of the form (outputs, state), where:
            outputs: A list of the same length as decoder_inputs of 2D Tensors with
                shape [batch_size x output_size] containing the generated outputs.
            state: The state of the decoder cell at the final time-step. It is a
                2D Tensor of shape [batch_size x cell.state_size].
    """
    with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
        scope = scope or "tied_rnn_seq2seq"
        _, enc_state = core_rnn.static_rnn(
                cell, encoder_inputs, dtype=dtype, scope=scope)
        variable_scope.get_variable_scope().reuse_variables()
        return rnn_decoder(
                decoder_inputs,
                enc_state,
                cell,
                loop_function=loop_function,
                scope=scope)
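The control flow described in the docstring (encode with one cell, then decode with the same cell starting from the last encoder state, optionally feeding outputs back through `loop_function`) can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the TensorFlow implementation: `ToyRNNCell` and `tied_seq2seq_sketch` are hypothetical names, and plain lists of floats stand in for the `[batch_size x input_size]` tensors (batch size 1, with input size equal to state size).

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


class ToyRNNCell:
    """Hypothetical stand-in for an RNNCell: h' = tanh(w_x * x + w_h * h + b).

    Weights are scalars applied elementwise; the output equals the new state.
    """

    def __init__(self, w_x: float, w_h: float, b: float) -> None:
        self.w_x, self.w_h, self.b = w_x, w_h, b
        self.calls = 0  # how many time-steps used these shared parameters

    def __call__(self, x: Vector, h: Vector) -> Tuple[Vector, Vector]:
        self.calls += 1
        new_h = [math.tanh(self.w_x * xi + self.w_h * hi + self.b)
                 for xi, hi in zip(x, h)]
        return new_h, new_h


def tied_seq2seq_sketch(
    encoder_inputs: Sequence[Vector],
    decoder_inputs: Sequence[Vector],
    cell: ToyRNNCell,
    loop_function: Optional[Callable[[Vector, int], Vector]] = None,
) -> Tuple[List[Vector], Vector]:
    # Encoder: run the cell from a zero state and keep only the final state.
    state: Vector = [0.0] * len(encoder_inputs[0])
    for x in encoder_inputs:
        _, state = cell(x, state)

    # Decoder: the SAME cell (same parameters) continues from that state.
    outputs: List[Vector] = []
    prev: Optional[Vector] = None
    for i, inp in enumerate(decoder_inputs):
        if loop_function is not None and prev is not None:
            # Feed the previous output back in; decoder_inputs[i] is ignored.
            inp = loop_function(prev, i)
        output, state = cell(inp, state)
        outputs.append(output)
        prev = output
    # Like rnn_decoder: one output per decoder step, but only the final state.
    return outputs, state


cell = ToyRNNCell(w_x=0.5, w_h=0.9, b=0.0)
enc = [[1.0, -1.0], [0.5, 0.5]]
dec = [[0.0, 0.0]] * 3
outputs, final_state = tied_seq2seq_sketch(enc, dec, cell)
print(len(outputs), cell.calls)  # 3 5
```

Tying shows up here as the single `cell` object being called for all five steps (two encoder, three decoder). In the TensorFlow function, the same effect comes from running `core_rnn.static_rnn` and `rnn_decoder` under the same `scope` and calling `reuse_variables()` in between, so the decoder reuses the encoder's variables instead of creating new ones.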