seq2seq.py source code (Python)

Project: Variational-Recurrent-Autoencoder-Tensorflow  Author: Chung-I
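import numpy as np
import tensorflow as tf  # written against the TensorFlow 1.x API

# DiagonalGaussian is defined elsewhere in this project and is not shown here.


def sample(means,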
           logvars,
           latent_dim,
           iaf=True,
           kl_min=None,
           anneal=False,
           kl_rate=None,
           dtype=None):
  """Perform sampling and calculate KL divergence.

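  Args:
    means: tensor of shape (batch_size, latent_dim), posterior means.
    logvars: tensor of shape (batch_size, latent_dim), posterior log-variances.
    latent_dim: dimension of the latent space.
    iaf: if True, transform the posterior sample with a linear inverse
      autoregressive flow (IAF); otherwise use a plain diagonal Gaussian.
    kl_min: optional per-dimension lower bound applied to the batch-averaged
      KL term before it enters kl_obj (a "free bits" style constraint).
    anneal: if True, scale kl_obj by kl_rate (KL cost annealing).
    kl_rate: annealing weight; only used when anneal is True.
    dtype: data type of the IAF variables and prior tensors.
  Returns:
    latent_vector: sampled latent variable, a tensor of shape (batch_size, latent_dim).
    kl_obj: scalar objective to be minimized for the KL term.
    kl_cost: scalar KL divergence, summed over latent dimensions and averaged
      over the batch (closed form without IAF, a single-sample estimate with IAF).
  """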
  if iaf:
    with tf.variable_scope('iaf'):
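      # Standard normal prior N(0, I) and diagonal Gaussian posterior q(z|x).
      prior = DiagonalGaussian(tf.zeros_like(means, dtype=dtype),
                               tf.zeros_like(logvars, dtype=dtype))
      posterior = DiagonalGaussian(means, logvars)
      # posterior.sample is a draw z ~ q(z|x), read as an attribute, not called.
      z = posterior.sample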

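      # Log density of the draw z under the posterior.
      logqs = posterior.logps(z)
      # Linear IAF: latent_vector = z L, with L lower triangular and a unit
      # diagonal. det(L) = 1, so no log-Jacobian correction is needed and
      # logqs can be reused for latent_vector.
      L = tf.get_variable("inverse_cholesky", [latent_dim, latent_dim], dtype=dtype,
                          initializer=tf.zeros_initializer())
      diag_one = tf.ones([latent_dim], dtype=L.dtype.base_dtype)
      L = tf.matrix_set_diag(L, diag_one)
      mask = np.tril(np.ones([latent_dim, latent_dim]))
      L = L * tf.constant(mask, dtype=L.dtype)  # zero the strictly upper triangle
      latent_vector = tf.matmul(z, L)
      logps = prior.logps(latent_vector)
      # Single-sample estimate of KL(q || p).
      kl_cost = logqs - logps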
  else:
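    # Reparameterized sample from N(means, exp(logvars)).
    noise = tf.random_normal(tf.shape(means), dtype=means.dtype)
    latent_vector = means + tf.exp(0.5 * logvars) * noise
    # Closed-form KL(N(means, exp(logvars)) || N(0, I)), per dimension.
    kl_cost = -0.5 * (logvars - tf.square(means) - tf.exp(logvars) + 1.0)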
  kl_ave = tf.reduce_mean(kl_cost, [0])  # average kl_cost over the batch dimension
  kl_obj = kl_cost = tf.reduce_sum(kl_ave)  # sum over latent dimensions
  if kl_min:
    # Free bits: each dimension contributes at least kl_min to the objective.
    kl_obj = tf.reduce_sum(tf.maximum(kl_ave, kl_min))
  if anneal:
    kl_obj = kl_obj * kl_rate

  return latent_vector, kl_obj, kl_cost  # both kl_obj and kl_cost are scalars
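The two branches compute kl_cost differently: without IAF it is the closed-form KL divergence between N(means, exp(logvars)) and N(0, I) per dimension, while with IAF it is log q(z) - log p(z') at a single draw. As a rough sanity check, here is a plain-Python sketch (standard library only; the helper names are hypothetical and not part of the project) showing, for one dimension and with the flow left out, that averaging the single-draw estimate over many samples approaches the closed form.

```python
import math
import random


def gaussian_logpdf(x, mean, logvar):
    """Log density of N(mean, exp(logvar)) evaluated at x."""
    return -0.5 * (math.log(2.0 * math.pi) + logvar
                   + (x - mean) ** 2 / math.exp(logvar))


def analytic_kl(mean, logvar):
    """Closed-form KL(N(mean, exp(logvar)) || N(0, 1)), as in the non-IAF branch."""
    return -0.5 * (logvar - mean ** 2 - math.exp(logvar) + 1.0)


def monte_carlo_kl(mean, logvar, num_samples, seed=0):
    """Average of log q(z) - log p(z) over draws z ~ q, as in the IAF branch
    with the flow left out."""
    rng = random.Random(seed)
    std = math.exp(0.5 * logvar)
    total = 0.0
    for _ in range(num_samples):
        z = mean + std * rng.gauss(0.0, 1.0)  # reparameterized draw
        total += gaussian_logpdf(z, mean, logvar) - gaussian_logpdf(z, 0.0, 0.0)
    return total / num_samples


exact = analytic_kl(0.5, -1.0)
estimate = monte_carlo_kl(0.5, -1.0, num_samples=100000)
print(f"closed form: {exact:.4f}  Monte Carlo: {estimate:.4f}")
```

A single draw per example, as in sample(), gives an unbiased but noisy estimate; the closed form is exact but only available when the posterior stays a diagonal Gaussian.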
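The linear IAF step relies on L being lower triangular with a unit diagonal: the map z -> z L then has Jacobian determinant 1, so logqs needs no correction, while the off-diagonal entries still let the posterior become correlated. The plain-Python sketch below (function names and example numbers are illustrative assumptions) rebuilds L the way sample() does, then checks that det(L) = 1 and that the covariance L^T diag(var) L of the transformed sample has the same determinant as diag(var).

```python
def unit_lower_triangular(weights):
    """Keep the strict lower triangle of `weights` and put ones on the diagonal,
    mirroring tf.matrix_set_diag followed by the np.tril mask in sample()."""
    n = len(weights)
    return [[1.0 if i == j else (weights[i][j] if i > j else 0.0)
             for j in range(n)] for i in range(n)]


def row_times_matrix(z, m):
    """Row vector times matrix, i.e. the tf.matmul(z, L) step for one example."""
    return [sum(z[i] * m[i][j] for i in range(len(z))) for j in range(len(m[0]))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    return [row_times_matrix(row, b) for row in a]


def determinant(m):
    """Determinant via Gaussian elimination with partial pivoting."""
    a = [row[:] for row in m]
    n = len(a)
    det = 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if a[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            det = -det
        det *= a[col][col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return det


# Arbitrary raw weights; entries on and above the diagonal are discarded.
raw = [[5.0, 9.0, 9.0],
       [0.7, 5.0, 9.0],
       [-1.2, 0.4, 5.0]]
L = unit_lower_triangular(raw)

z = [1.0, 2.0, 3.0]
z_flow = row_times_matrix(z, L)  # each z_flow[j] depends on z[j] and z[j+1:]

variances = [0.5, 2.0, 1.5]
diag_var = [[variances[i] if i == j else 0.0 for j in range(3)] for i in range(3)]
cov = matmul(matmul(transpose(L), diag_var), L)  # covariance of z L
print("det(L) =", determinant(L))
print("cov of z L =", cov)
print("det(cov) =", determinant(cov), "vs prod(var) =", 0.5 * 2.0 * 1.5)
```

Because the TensorFlow variable is initialized to zeros, L starts out as the identity, so training begins with an uncorrelated posterior and learns the off-diagonal entries from there.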
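The final reduction averages kl_cost over the batch, sums over latent dimensions, and applies free bits and annealing only to kl_obj, leaving kl_cost as the reported value. A plain-Python mirror of that logic (the function kl_terms and the numbers are hypothetical, assuming kl_cost has shape (batch_size, latent_dim)) makes the difference visible:

```python
def kl_terms(kl_per_example, kl_min=None, anneal=False, kl_rate=None):
    """Mirror the final reduction in sample() for a nested list of shape
    (batch_size, latent_dim); returns (kl_obj, kl_cost)."""
    batch_size = len(kl_per_example)
    latent_dim = len(kl_per_example[0])
    # Average over the batch dimension (tf.reduce_mean(kl_cost, [0])).
    kl_ave = [sum(row[d] for row in kl_per_example) / batch_size
              for d in range(latent_dim)]
    # Sum over latent dimensions; kl_cost is what gets reported.
    kl_cost = sum(kl_ave)
    kl_obj = kl_cost
    if kl_min:
        # Free bits: each dimension contributes at least kl_min to the objective.
        kl_obj = sum(max(k, kl_min) for k in kl_ave)
    if anneal:
        kl_obj = kl_obj * kl_rate
    return kl_obj, kl_cost


kl = [[0.02, 1.0],
      [0.04, 3.0]]
print(kl_terms(kl))
print(kl_terms(kl, kl_min=0.25, anneal=True, kl_rate=0.1))
```

kl_cost is about 2.03 in both calls; only kl_obj is clamped and scaled. Dimensions whose average KL is already below kl_min are held at kl_min in the objective, so they are not pushed further toward the prior.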