import tensorflow as tf  # TensorFlow 1.x API (tf.multinomial, tf.to_int32, tf.while_loop)


def sampler(symbols_to_logits_fn, initial_ids, sample_num, decode_length,
            vocab_size, eos_id, features=None):
  """Draws sample_num sequences per input by multinomial sampling at each step."""
  batch_size = tf.shape(initial_ids)[0]
  seqlen = tf.constant(0)

  # Expand each batch entry to sample_num copies:
  # (batch_size,) -> (batch_size, sample_num) -> (batch_size, sample_num, 1)
  alive_seq = tf.tile(tf.expand_dims(initial_ids, 1), [1, sample_num])
  alive_seq = tf.expand_dims(alive_seq, 2)  # (batch_size, sample_num, 1)
  # Flatten the batch and sample dimensions so each row is one candidate sequence.
  sa = tf.shape(alive_seq)
  alive_seq = tf.reshape(alive_seq, [sa[0] * sa[1], 1])

  def _is_finished(i, alive_seq):
    # Loop condition: keep decoding until decode_length steps have been taken.
    return i < decode_length

  def inner_loop(i, alive_seq):
    # Get next-token logits for every row, sample one token per row,
    # and append it to the running sequences.
    logit = symbols_to_logits_fn(alive_seq)[0]
    new_samples = tf.multinomial(logit, 1)
    new_samples = tf.to_int32(new_samples)
    alive_seq = tf.concat([alive_seq, new_samples], 1)
    return (i + 1, alive_seq)

  (_, alive_seq) = tf.while_loop(
      _is_finished,
      inner_loop,
      [seqlen, alive_seq],
      shape_invariants=[
          tf.TensorShape([]),
          tf.TensorShape([None, None]),  # alive_seq grows along the time axis
      ],
      parallel_iterations=1,
      back_prop=False)

  # Note: this static shape only holds when batch_size == 1; in general the
  # leading dimension is batch_size * sample_num.
  alive_seq.set_shape((sample_num, None))
  return alive_seq
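
For reference, here is a minimal usage sketch (not from the original code): it assumes TensorFlow 1.x graph mode and a toy symbols_to_logits_fn that returns uniform logits over a small vocabulary, so the call pattern and the output shape can be checked end to end. The names toy_symbols_to_logits_fn and VOCAB_SIZE are introduced only for this example.

# Hypothetical usage sketch; assumes sampler() above is already defined.
import tensorflow as tf

VOCAB_SIZE = 8

def toy_symbols_to_logits_fn(ids):
  # Uniform logits over the vocabulary for every row; sampler() indexes the
  # return value with [0], so wrap the logits in a tuple.
  rows = tf.shape(ids)[0]
  return (tf.zeros([rows, VOCAB_SIZE]),)

initial_ids = tf.zeros([1], dtype=tf.int32)  # batch_size == 1
samples = sampler(toy_symbols_to_logits_fn, initial_ids,
                  sample_num=4, decode_length=10,
                  vocab_size=VOCAB_SIZE, eos_id=1)

with tf.Session() as sess:
  print(sess.run(samples).shape)  # (4, 11): 4 samples, prefix + 10 sampled tokens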