def _sample_n(self, n, seed=None):
    """Draw `n` relaxed one-hot samples in log space (Gumbel-softmax trick).

    Each sample is `log_softmax((logits + g) / temperature)`, where `g` is
    standard Gumbel noise computed as `g = -log(-log(u))` with `u` drawn
    uniformly from the open interval `(0, 1)`.

    Args:
      n: Number of samples to draw; becomes the leading dimension of the
        result.
      seed: Optional seed forwarded to `random_ops.random_uniform`.

    Returns:
      A `Tensor` with shape `[n] + shape(self.logits)` and dtype `self.dtype`.
      Each slice along the last axis is the output of `log_softmax`, i.e. a
      log-probability vector.
    """
    sample_shape = array_ops.concat(([n], array_ops.shape(self.logits)), 0)
    # Broadcast the logits up to the full sample shape. The ones tensor must
    # share the logits' dtype (`self.dtype`); the float32 default would fail
    # to multiply against float64/float16 logits.
    logits = self.logits * array_ops.ones(sample_shape, dtype=self.dtype)
    # Collapse all leading dimensions so the matrix is
    # [n * batch_size, self.event_size] (presumably event_size is the size of
    # the last axis of self.logits). When the static rank is already 2,
    # self.logits has no batch dimensions and no reshape is needed.
    if logits.get_shape().ndims == 2:
        logits_2d = logits
    else:
        logits_2d = array_ops.reshape(logits, [-1, self.event_size])
    # u must lie in the open interval (0, 1) so that both logs below stay
    # finite. The lower bound is the smallest positive *normal* number: a
    # subnormal bound (e.g. np.nextafter(0, 1)) may be flushed to zero on
    # some hardware, giving log(0) = -inf. Note that `as_numpy_dtype` is a
    # property returning the numpy type itself; it must not be called.
    minval = np.finfo(self.dtype.as_numpy_dtype).tiny
    uniform = random_ops.random_uniform(
        shape=array_ops.shape(logits_2d),
        minval=minval,
        maxval=1,
        dtype=self.dtype,
        seed=seed)
    gumbel = -math_ops.log(-math_ops.log(uniform))
    # NOTE(review): assumes self.temperature is a scalar (or otherwise
    # broadcasts against the [n * batch_size, event_size] matrix) — confirm
    # if batched temperatures are ever allowed.
    noisy_logits = math_ops.div(gumbel + logits_2d, self.temperature)
    samples = nn_ops.log_softmax(noisy_logits)
    # Restore the [n] + batch_shape + [event_size] layout.
    return array_ops.reshape(samples, sample_shape)
relaxed_onehot_categorical.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录