common_layers.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def sample_with_temperature(logits, temperature):
  """Either argmax or random sampling.

  Args:
    logits: a Tensor.
    temperature: a float  0.0=argmax 1.0=random

  Returns:
    a Tensor with one fewer dimension than logits.
  """
  if temperature == 0.0:
    return tf.argmax(logits, -1)
  else:
    assert temperature > 0.0
    reshaped_logits = (
        tf.reshape(logits, [-1, shape_list(logits)[-1]]) / temperature)
    choices = tf.multinomial(reshaped_logits, 1)
    choices = tf.reshape(choices,
                         shape_list(logits)[:logits.get_shape().ndims - 1])
    return choices
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号