def _sample(self, n_samples):
    """Draw samples using the Gumbel-softmax (Concrete) construction.

    Each draw is ``softmax((logits + g) / temperature)``, where ``g`` is
    Gumbel noise obtained as ``-log(-log(u))`` from uniform samples ``u``.
    The softmax is taken along the last axis (``tf.nn.softmax`` default).

    :param n_samples: Number of samples to draw. May be a Python ``int``
        (then the leading static dimension is known) or a scalar ``Tensor``.
    :return: A ``Tensor`` of shape ``[n_samples] + shape(self.logits)``.
    """
    params = (self.logits, self.temperature)
    if self.is_reparameterized:
        logits, temperature = params
    else:
        # Samples must not carry gradients back into the parameters when
        # the distribution is not reparameterized.
        logits, temperature = [tf.stop_gradient(p) for p in params]

    # Prepend the sample axis to the (dynamic) shape of the logits.
    sample_shape = tf.concat([[n_samples], tf.shape(self.logits)], 0)

    # NOTE(review): helper defined elsewhere; presumably samples U(0, 1)
    # excluding both endpoints so that both logs below stay finite.
    u = open_interval_standard_uniform(sample_shape, self.dtype)

    # Inverse-CDF transform of uniform noise into standard Gumbel noise.
    # TODO: Add Gumbel distribution (and draw this noise from it directly).
    neg_log_u = -tf.log(u)
    gumbel_noise = -tf.log(neg_log_u)

    perturbed = logits + gumbel_noise
    samples = tf.nn.softmax(perturbed / temperature)

    # Only a Python int gives a statically known sample dimension.
    if isinstance(n_samples, int):
        leading_dim = n_samples
    else:
        leading_dim = None
    static_shape = tf.TensorShape([leading_dim]).concatenate(
        logits.get_shape())
    samples.set_shape(static_shape)
    return samples
评论列表
文章目录