def multinomial_sample(x, vocab_size, temperature):
    """Sample indices along the last axis of an n-dimensional tensor.

    Args:
        x: Tensor whose values are regrouped into rows of `vocab_size`
            entries before sampling. It is handed to `tf.multinomial`, which
            interprets its input as unnormalized log-probabilities (logits).
        vocab_size: Number of columns used when flattening `x`; presumably
            equal to the size of the last axis of `x` (confirm with callers).
        temperature: Divisor applied to `x` before sampling. It is tested
            with a plain Python `if`, so it is presumably a Python number
            rather than a tensor (confirm with callers). When it is not
            greater than zero, the argmax over the last axis is returned
            instead of a random draw.

    Returns:
        An int32 tensor shaped like `x` without its last axis.
    """
    if temperature > 0:
        # tf.multinomial expects 2-D input, so collapse every leading axis
        # into a single batch axis before drawing one sample per row.
        flat_logits = tf.reshape(x, [-1, vocab_size])
        samples = tf.multinomial(flat_logits / temperature, 1)
    else:
        samples = tf.argmax(x, axis=-1)
    # Give the result the leading dimensions of `x`; in the sampling branch
    # this also removes the trailing "one sample" axis.
    reshaped_samples = tf.reshape(samples, common_layers.shape_list(x)[:-1])
    # tf.cast is the supported replacement for the deprecated tf.to_int32.
    return tf.cast(reshaped_samples, tf.int32)
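# Comment list  (web-page navigation residue, not part of the code)
# Table of contents  (web-page navigation residue, not part of the code)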