def sample(model, data, sess, seed=None):
    """Draw a sample from the model.

    Args:
        model (dict): return value of one of the get_model functions. Only
            ``model['sampling']`` is used. Its ``'sequence'`` and
            ``'final_state'`` entries are concatenated with ``+`` to form the
            fetch list, so they must be lists (or tuples) of fetchable ops.
        data (dict): return value of the get_data function. ``'vocab'`` is
            inverted to map model outputs back to vocabulary keys. The
            inverted mapping is cached in place as ``data['inverse_vocab']``.
            ``'go_symbol'`` is required when ``seed`` is None.
        sess (tf.Session): a session in which to run the sampling ops.
        seed (Optional): either a sequence to feed into the first few batches
            or None, in which case just the GO symbol is fed in.

    Returns:
        str: the sample for one randomly chosen batch element, truncated to
        ``FLAGS.sample_length`` symbols.
    """
    if 'inverse_vocab' not in data:
        # Cached on `data` so repeated calls don't rebuild the mapping.
        data['inverse_vocab'] = {b: a for a, b in data['vocab'].items()}
    if seed is not None:
        # run it a bit to get a starting state
        state, inputs = _init_nextstep_state(model, data, seed, sess)
    else:
        # otherwise start from zero, feeding only the GO symbol
        state = sess.run(model['sampling']['initial_state'])
        inputs = np.array(
            [[data['go_symbol']] * FLAGS.batch_size] * FLAGS.sequence_length)

    sequence_ops = model['sampling']['sequence']
    # The fetched results are the step outputs followed by the final state.
    # Split them exactly at the number of output ops so that no output is
    # dropped and no output leaks into the state fed back on the next step.
    num_outputs = len(sequence_ops)
    fetches = sequence_ops + model['sampling']['final_state']

    seq = []
    # now just roll through
    while len(seq) < FLAGS.sample_length:
        results = sess.run(
            fetches,
            _fill_feed(data, inputs,
                       state_var=model['sampling']['initial_state'],
                       state_val=state))
        seq.extend(results[:num_outputs])
        state = results[num_outputs:]
        # Feed the most recent output back in as every input position.
        inputs = np.array([seq[-1]] * FLAGS.sequence_length)

    # randrange excludes its upper bound, so this is always a valid batch
    # index (randint(0, batch_size) could return batch_size itself).
    batch_index = random.randrange(FLAGS.batch_size)
    # The loop may overshoot; keep exactly the requested number of symbols.
    return ''.join(str(data['inverse_vocab'][symbol[batch_index]])
                   for symbol in seq[:FLAGS.sample_length])
评论列表
文章目录