def sample(self, n, max_length=None, z=None, **kwargs):
    """Sample with an optional conditional embedding `z`.

    Args:
      n: Number of samples to draw.
      max_length: Optional maximum sample length; forwarded unchanged to
        `self.decoder.sample`.
      z: Optional embedding. When its first dimension is statically known it
        must equal `n`; when it is only known at run time the check is
        skipped. If `z` is omitted and `self.hparams.conditional` is set, a
        standard-normal `z` of shape `[n, self.hparams.z_size]` is drawn.
      **kwargs: Additional keyword arguments forwarded to
        `self.decoder.sample`.

    Returns:
      Whatever `self.decoder.sample(n, max_length, z, **kwargs)` returns.

    Raises:
      ValueError: If `z` is given and its statically-known first dimension
        differs from `n`.
    """
    if z is not None:
        # `z.shape[0]` is a `tf.Dimension` (exposing `.value`) under TF1 shape
        # semantics, and a plain int/None under TF2 semantics. It is None when
        # the batch size is only known at run time; skip the static check in
        # that case rather than failing on `None != n` and `'%d' % None`.
        batch_dim = z.shape[0]
        batch_size = getattr(batch_dim, 'value', batch_dim)
        if batch_size is not None and batch_size != n:
            raise ValueError(
                '`z` must have a first dimension that equals `n` when given. '
                'Got: %d vs %d' % (batch_size, n))
    if self.hparams.conditional and z is None:
        tf.logging.warning(
            'Sampling from conditional model without `z`. Using random `z`.')
        # Standard normal prior: zero mean, unit scale.
        normal_shape = [n, self.hparams.z_size]
        normal_dist = tf.contrib.distributions.Normal(
            loc=tf.zeros(normal_shape), scale=tf.ones(normal_shape))
        z = normal_dist.sample()
    return self.decoder.sample(n, max_length, z, **kwargs)
评论列表
文章目录