base_model.py source code

python

Project: magenta  Author: tensorflow
def sample(self, n, max_length=None, z=None, **kwargs):
  """Sample with an optional conditional embedding `z`."""
  # When `z` is supplied, there must be exactly one latent vector per sample.
  if z is not None and z.shape[0].value != n:
    raise ValueError(
        '`z` must have a first dimension that equals `n` when given. '
        'Got: %d vs %d' % (z.shape[0].value, n))

  # A conditional model expects `z`; if none is given, fall back to drawing
  # an [n, z_size] latent batch from a standard normal N(0, I).
  if self.hparams.conditional and z is None:
    tf.logging.warning(
        'Sampling from conditional model without `z`. Using random `z`.')
    normal_shape = [n, self.hparams.z_size]
    normal_dist = tf.contrib.distributions.Normal(
        loc=tf.zeros(normal_shape), scale=tf.ones(normal_shape))
    z = normal_dist.sample()

  # Sequence generation itself is delegated to the decoder.
  return self.decoder.sample(n, max_length, z, **kwargs)
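The control flow above (validate the batch size of `z`, fall back to a standard-normal `z` for conditional models, then delegate to the decoder) can be sketched without TensorFlow. This is a minimal, framework-free illustration only: `HParams`, `StubDecoder`, and `SketchModel` are hypothetical stand-ins, lists of lists replace tensors, and `random.gauss` replaces `tf.contrib.distributions.Normal`. It is not the magenta implementation.

```python
import logging
import random
from dataclasses import dataclass


@dataclass
class HParams:
    # Hypothetical stand-in for the model's hyperparameters.
    conditional: bool
    z_size: int


class StubDecoder:
    """Hypothetical decoder: returns one placeholder sequence per sample."""

    def sample(self, n, max_length=None, z=None):
        length = max_length or 4
        if z is None:
            return [[0] * length for _ in range(n)]
        # Toy "token": 1 if the first latent value is positive, else 0.
        return [[int(row[0] > 0)] * length for row in z]


class SketchModel:
    def __init__(self, hparams, decoder):
        self.hparams = hparams
        self.decoder = decoder

    def sample(self, n, max_length=None, z=None, rng=None):
        # Same check as above: one latent vector per requested sample.
        if z is not None and len(z) != n:
            raise ValueError(
                '`z` must have a first dimension that equals `n` when given. '
                'Got: %d vs %d' % (len(z), n))
        if self.hparams.conditional and z is None:
            logging.warning(
                'Sampling from conditional model without `z`. Using random `z`.')
            rng = rng or random.Random()
            # Fill an [n, z_size] matrix with independent N(0, 1) draws.
            z = [[rng.gauss(0.0, 1.0) for _ in range(self.hparams.z_size)]
                 for _ in range(n)]
        return self.decoder.sample(n, max_length, z)


if __name__ == '__main__':
    model = SketchModel(HParams(conditional=True, z_size=8), StubDecoder())
    seqs = model.sample(3, max_length=5, rng=random.Random(0))
    print(len(seqs), len(seqs[0]))  # 3 5
```

Passing an explicit `rng` is a design choice of this sketch that makes the fallback `z` reproducible in tests; the original relies on TensorFlow's own random ops instead.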