train.py source code

python

Project: generating_sequences  Author: PFCM
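import random

import numpy as np

# FLAGS, _init_nextstep_state and _fill_feed are defined elsewhere in
# train.py (not shown in this excerpt).


def sample(model, data, sess, seed=None):
    """Draws a sample from the model.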

    Args:
        model (dict): return value of one of the get_model functions.
        data (dict): return from the get_data function.
        sess (tf.Session): a session in which to run the sampling ops.
        seed (Optional): either a sequence to feed into the first few batches
            or None, in which case just the GO symbol is fed in.

    Returns:
        str: the sample.
    """
    if 'inverse_vocab' not in data:
        data['inverse_vocab'] = {b: a for a, b in data['vocab'].items()}

    if seed is not None:
        # run it a bit to get a starting state
        state, inputs = _init_nextstep_state(model, data, seed, sess)
    else:
        # otherwise start from zero
        state = sess.run(model['sampling']['initial_state'])
        inputs = np.array(
            [[data['go_symbol']] * FLAGS.batch_size] * FLAGS.sequence_length)

    seq = []

    # now roll the model forward one chunk at a time until we have enough symbols
    while len(seq) < FLAGS.sample_length:
        results = sess.run(
            model['sampling']['sequence'] + model['sampling']['final_state'],
            _fill_feed(data, inputs,
                       state_var=model['sampling']['initial_state'],
                       state_val=state))
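        # sess.run returns the sampled outputs followed by the final state
        # tensors, so split the flat result list at the number of outputs.
        num_outputs = len(model['sampling']['sequence'])
        seq.extend(results[:num_outputs])
        state = results[num_outputs:]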
        inputs = np.array(
            [seq[-1]] * FLAGS.sequence_length)

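    # randrange excludes its upper bound; randint(0, batch_size) could return
    # batch_size and index one past the end of the batch.
    batch_index = random.randrange(FLAGS.batch_size)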
    samp = ''.join([str(data['inverse_vocab'][symbol[batch_index]])
                    for symbol in seq])

    return samp
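The sampling loop above works in fixed-size chunks: each sess.run call returns a flat list holding one symbol array per time step followed by the final state tensors, and that state is fed back in as the initial state of the next chunk. The sketch below reproduces that control flow without TensorFlow so it can be run directly. It is only an illustration under stated assumptions: toy_run is a hypothetical stand-in for sess.run (a counter instead of a trained model), and SEQUENCE_LENGTH, BATCH_SIZE and VOCAB_SIZE are hypothetical values standing in for the FLAGS used in train.py.

```python
import numpy as np

# Hypothetical stand-ins for FLAGS.sequence_length and FLAGS.batch_size.
SEQUENCE_LENGTH = 4
BATCH_SIZE = 3
VOCAB_SIZE = 5  # hypothetical vocabulary size for the toy model


def toy_run(inputs, state):
    """Hypothetical stand-in for sess.run(sequence + final_state, feed_dict).

    Returns a flat list: SEQUENCE_LENGTH symbol arrays of shape [BATCH_SIZE],
    followed by the new state arrays. The toy "model" ignores its inputs and
    just counts upward, so the carried-over state is visible in the output.
    """
    (counter,) = state
    outputs = []
    for _ in range(SEQUENCE_LENGTH):
        counter = counter + 1
        outputs.append(counter % VOCAB_SIZE)
    return outputs + [counter]


def roll_sample(sample_length, go_symbol=0):
    """Sample chunk by chunk, carrying the state between chunks."""
    state = [np.zeros(BATCH_SIZE, dtype=np.int64)]  # start from zero
    inputs = np.full((SEQUENCE_LENGTH, BATCH_SIZE), go_symbol)
    seq = []
    while len(seq) < sample_length:
        results = toy_run(inputs, state)
        # Outputs come first, then state; split at the number of outputs.
        seq.extend(results[:SEQUENCE_LENGTH])
        state = results[SEQUENCE_LENGTH:]
        # Feed the most recent symbols back in as the next chunk's inputs.
        inputs = np.array([seq[-1]] * SEQUENCE_LENGTH)
    return seq


seq = roll_sample(sample_length=6)
print(len(seq))                  # 8
print([int(s[0]) for s in seq])  # [1, 2, 3, 4, 0, 1, 2, 3]
```

Because the loop only checks the length between chunks, it can overshoot: asking for 6 symbols yields 8 here, and the same holds for sample() when sample_length is not a multiple of the chunk size. The second chunk continuing from 5 (printed as 0) shows that the state really is carried across sess.run calls.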
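The last step of sample() inverts the vocabulary (token to id becomes id to token), picks one batch column at random and joins that column's symbols into a string. A minimal standalone sketch of that decoding step follows; decode_column is a hypothetical helper, and it assumes, as the dictionary inversion in sample() implies, that vocab maps each token to an integer id and that every time step is an integer array of shape [batch_size].

```python
import random

import numpy as np


def decode_column(seq, vocab, rng=random):
    """Decode one batch column of a sampled sequence into a string.

    seq: list of integer arrays of shape [batch_size], one per time step.
    vocab: dict mapping token -> integer id (inverted here, as in sample()).
    """
    inverse_vocab = {index: token for token, index in vocab.items()}
    batch_size = len(seq[0])
    # randrange(n) picks from 0..n-1; randint(0, n) is inclusive of n and
    # would occasionally index one past the end of the batch.
    column = rng.randrange(batch_size)
    return ''.join(str(inverse_vocab[int(step[column])]) for step in seq)


vocab = {'a': 0, 'b': 1, 'c': 2}
seq = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])]
print(decode_column(seq, vocab))  # abc (every column is identical here)
```

Passing a seeded random.Random instance as rng makes the chosen column reproducible, which can help when comparing samples across runs.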