def _sample(self, n_samples):
    """Draw ``n_samples`` multinomial count samples.

    For every sample, ``self.n_experiments`` categorical trials over
    ``self.n_categories`` categories are drawn from ``self.logits``.
    The one-hot encoded trials are summed, giving per-category counts.

    Args:
        n_samples: Number of samples to draw. When it is a Python ``int``
            the leading static dimension of the result is set; otherwise
            that dimension is left unknown.

    Returns:
        A tensor of dtype ``self.dtype`` with shape
        ``[n_samples] + batch_shape + [n_categories]``, holding counts.
        This presumes ``self.logits`` has shape
        ``batch_shape + [n_categories]`` (TODO confirm against the
        class constructor).
    """
    # tf.multinomial needs a rank-2 [rows, categories] logits matrix, so
    # collapse any batch dimensions unless the logits are already rank 2
    # (an unknown rank also goes through the reshape).
    if self.logits.get_shape().ndims != 2:
        logits_2d = tf.reshape(self.logits, [-1, self.n_categories])
    else:
        logits_2d = self.logits

    # One row of draws per batch element; transpose so the draw index
    # becomes the leading axis, ready to be split into samples/experiments.
    n_draws = n_samples * self.n_experiments
    draws = tf.transpose(tf.multinomial(logits_2d, n_draws))

    target_shape = tf.concat(
        [[n_samples, self.n_experiments], self.batch_shape], 0)
    draws = tf.reshape(draws, target_shape)

    # The dynamic reshape loses static shape info; restore whatever is
    # known at graph-construction time (Python ints become fixed dims).
    static_dims = [dim if isinstance(dim, int) else None
                   for dim in (n_samples, self.n_experiments)]
    draws.set_shape(
        tf.TensorShape(static_dims).concatenate(self.get_batch_shape()))

    # Summing the one-hot trials over the experiments axis (axis 1) turns
    # individual category draws into per-category counts.
    one_hot_draws = tf.one_hot(draws, self.n_categories, dtype=self.dtype)
    return tf.reduce_sum(one_hot_draws, axis=1)
评论列表
文章目录