util.py source code


Project: proximity_vi · Author: altosaar
import tensorflow as tf


def build_vimco_loss(cfg, l, log_q_h):
  """Builds negative VIMCO loss as in the paper.

  Reference: Variational Inference for Monte Carlo Objectives, Algorithm 1
  https://arxiv.org/abs/1602.06725
  """
  k, b = l.get_shape().as_list()  # n_samples, batch_size
  kf = tf.cast(k, tf.float32)
  if cfg['optim/geometric_mean']:
    # implicit multi-sample objective (importance-sampled ELBO)
    l_logsumexp = tf.reduce_logsumexp(l, [0], keep_dims=True)
    L_hat = l_logsumexp - tf.log(kf)
  else:
    # standard ELBO
    L_hat = tf.reduce_mean(l, [0], keep_dims=True)
  s = tf.reduce_sum(l, 0, keep_dims=True)
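  # Build l_i with shape [k, k, b]: slice i equals l, except that sample i's
  # own log-weight is replaced by the leave-one-out average of the other
  # samples' log-weights (the log of their geometric mean).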
  diag_mask = tf.expand_dims(tf.diag(tf.ones([k], dtype=tf.float32)), -1)
  off_diag_mask = 1. - diag_mask
  diff = tf.expand_dims(s - l, 0)  # expand for proper broadcasting
  l_i_diag = 1. / (kf - 1.) * diff * diag_mask
  l_i_off_diag = off_diag_mask * tf.stack([l] * k)
  l_i = l_i_diag + l_i_off_diag
  if cfg['optim/geometric_mean']:
    L_hat_minus_i = tf.reduce_logsumexp(l_i, [1]) - tf.log(kf)
    w = tf.stop_gradient(tf.exp(l - l_logsumexp))
  else:
    L_hat_minus_i = tf.reduce_mean(l_i, [1])
    w = 1.
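  # Per-sample learning signal L_hat - L_hat_minus_i; stop_gradient makes it
  # a constant coefficient on the score-function term below.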
  local_l = tf.stop_gradient(L_hat - L_hat_minus_i)
  if not cfg['optim/geometric_mean']:
    # correction factor for multiplying by 1. / (kf - 1.) above
    # to verify this, work out 2x2 matrix of samples by hand
    local_l = local_l * k
  loss = local_l * log_q_h + w * l
  return loss / tf.to_float(b)
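A minimal usage sketch, assuming the TF1-style API the snippet itself uses; the shapes, the cfg value, and the variable setup below are illustrative assumptions, not taken from the project:

import tensorflow as tf

# Hypothetical setup: k=5 Monte Carlo samples, batch size 32. In a real model,
# l and log_q_h come from the inference network; here they are variables so
# that gradients have something to flow into.
l = tf.Variable(tf.random_normal([5, 32]))        # per-sample log-weights
log_q_h = tf.Variable(tf.random_normal([5, 32]))  # log q(h | x) per sample

cfg = {'optim/geometric_mean': True}  # multi-sample (logsumexp) bound
loss = build_vimco_loss(cfg, l, log_q_h)

# The gradient of the summed surrogate is the VIMCO gradient estimator.
grads = tf.gradients(tf.reduce_sum(loss), [l, log_q_h])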