def sample_with_temperature(logits, temperature):
    """Pick class indices from logits by argmax or temperature sampling.

    Args:
        logits: a Tensor whose last dimension indexes the classes.
        temperature: a non-negative float. 0.0 selects the argmax over the
            last dimension. Any positive value divides the logits by it
            before drawing one sample per row with `tf.multinomial`, so 1.0
            samples from the logits unchanged.

    Returns:
        a Tensor with one fewer dimension than logits.

    Raises:
        ValueError: if temperature is negative or NaN.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let a negative temperature flip the sign of the logits.
    # `not (t >= 0.0)` also rejects NaN.
    if not temperature >= 0.0:
        raise ValueError(
            "temperature must be >= 0.0, got %r" % (temperature,))

    # Compute the shape once. Both branches flatten every leading dimension
    # into a single batch dimension and restore it afterwards.
    logits_shape = shape_list(logits)
    flat_logits = tf.reshape(logits, [-1, logits_shape[-1]])
    # `[:-1]` is equivalent to slicing with `ndims - 1` when the static rank
    # is known, and does not fail when `ndims` is None.
    output_shape = logits_shape[:-1]

    if temperature == 0.0:
        # NOTE(review): argmax is taken on the flattened 2-D view so that
        # high-rank logits follow the same path as sampling; the result is
        # the same as tf.argmax(logits, -1).
        choices = tf.argmax(flat_logits, -1)
    else:
        # NOTE(review): tf.multinomial is deprecated in newer TensorFlow in
        # favour of tf.random.categorical; kept because the TF version used
        # by this file is not visible here.
        choices = tf.multinomial(flat_logits / temperature, 1)
    return tf.reshape(choices, output_shape)
评论列表
文章目录