next_step.py source code


Project: generating_sequences · Author: PFCM

import tensorflow as tf  # note: this file targets the pre-1.0 TensorFlow API (tf.unpack, tf.nn.seq2seq)

def standard_nextstep_sample(cell, inputs, output_size, embedding, scope=None,
                             argmax=False, softmax_temperature=1):
    """Generate samples from the standard next step prediction model.
    Assumes that we are modelling sequences of discrete symbols.

    Args:
        cell (tf.nn.rnn_cell.RNNCell): a cell to reproduce the model.
        inputs: input variable, all but the first is ignored.
        output_size (int): the size of the vocabulary.
        embedding: the embedding matrix used.
        scope (Optional): variable scope, needs to line up with what was used
            to make the model for inference the first time around.
        argmax (Optional[bool]): if True, then instead of sampling we simply
            take the argmax of the logits; if False, we apply a softmax
            first and sample from it. Defaults to False.
        softmax_temperature (Optional[float]): the temperature for the
            softmax. The logits are divided by this before feeding into the
            softmax: a higher value means higher entropy. Defaults to 1.

    Returns:
        tuple: (initial_state, sequence, final_state) where
            - *initial_state* is a variable for the starting state of the net.
            - *sequence* is a list of length len(inputs) containing the
                sampled symbols.
            - *final_state* is the finishing state of the network, to pass
                along.
    """
    # have to be quite careful to ensure the scopes line up
    with tf.variable_scope(scope or 'rnn') as scope:
        inputs = tf.unpack(inputs)
        batch_size = inputs[0].get_shape()[0].value
        initial_state = cell.zero_state(batch_size, tf.float32)

        # get the output weights
        with tf.variable_scope('output_layer'):
            weights = tf.get_variable('weights',
                                      [cell.output_size, output_size])
            biases = tf.get_variable('bias',
                                     [output_size])

        # choose an appropriate loop function
        sequence = []
        if argmax:
            loop_fn = argmax_and_embed(embedding, output_list=sequence,
                                       output_projection=(weights, biases))
        else:
            loop_fn = sample_and_embed(embedding, softmax_temperature,
                                       output_list=sequence,
                                       output_projection=(weights, biases))

        all_states, fstate = tf.nn.seq2seq.rnn_decoder(
            inputs, initial_state, cell, loop_function=loop_fn, scope=scope)

    return [initial_state], sequence, [fstate]
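As a rough illustration of what the loop functions above do at each step (project the cell output to vocabulary logits, pick a symbol either greedily or by temperature-scaled sampling, then embed it as the next input), here is a minimal NumPy sketch. The names mirror the arguments above, but this is a standalone approximation for clarity, not the project's actual `argmax_and_embed` / `sample_and_embed` implementations:

```python
import numpy as np

def temperature_softmax(logits, temperature=1.0):
    # divide the logits by the temperature before the softmax:
    # a higher temperature flattens the distribution (higher entropy)
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()          # subtract max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()

def next_symbol_step(cell_output, weights, biases, embedding,
                     temperature=1.0, argmax=False,
                     rng=np.random.default_rng()):
    # project the cell output to vocabulary-sized logits
    logits = cell_output @ weights + biases
    if argmax:
        # greedy: take the most likely symbol
        symbol = int(np.argmax(logits))
    else:
        # stochastic: sample from the temperature-scaled softmax
        probs = temperature_softmax(logits, temperature)
        symbol = int(rng.choice(len(logits), p=probs))
    # look up the symbol's embedding to feed back in as the next input
    return symbol, embedding[symbol]
```

The `argmax` path is deterministic and tends to produce repetitive sequences; sampling with a temperature below 1 is a common middle ground, concentrating probability mass on the top candidates while keeping some diversity.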