import tensorflow as tf


def build_vimco_loss(cfg, l, log_q_h):
"""Builds negative VIMCO loss as in the paper.
Reference: Variational Inference for Monte Carlo Objectives, Algorithm 1
https://arxiv.org/abs/1602.06725
"""
k, b = l.get_shape().as_list() # n_samples, batch_size
kf = tf.cast(k, tf.float32)
if cfg['optim/geometric_mean']:
# implicit multi-sample objective (importance-sampled ELBO)
l_logsumexp = tf.reduce_logsumexp(l, [0], keep_dims=True)
L_hat = l_logsumexp - tf.log(kf)
else:
# standard ELBO
L_hat = tf.reduce_mean(l, [0], keep_dims=True)
    s = tf.reduce_sum(l, 0, keep_dims=True)
    diag_mask = tf.expand_dims(tf.diag(tf.ones([k], dtype=tf.float32)), -1)
    off_diag_mask = 1. - diag_mask
    # s - l leaves the sum over the other k - 1 samples for each index i
    diff = tf.expand_dims(s - l, 0)  # expand for proper broadcasting
    l_i_diag = 1. / (kf - 1.) * diff * diag_mask
    l_i_off_diag = off_diag_mask * tf.stack([l] * k)
    l_i = l_i_diag + l_i_off_diag
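    # Row i of l_i is l with its i-th entry swapped for the leave-one-out
    # value 1. / (kf - 1.) * sum_{j != i} l_j, so a single reduction over
    # axis 1 below yields the leave-one-out objective for every sample i.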
    if cfg['optim/geometric_mean']:
        L_hat_minus_i = tf.reduce_logsumexp(l_i, [1]) - tf.log(kf)
        # normalized importance weights, held fixed for the gradient
        w = tf.stop_gradient(tf.exp(l - l_logsumexp))
    else:
        L_hat_minus_i = tf.reduce_mean(l_i, [1])
        w = 1.
    # per-sample learning signals for the score-function term
    local_l = tf.stop_gradient(L_hat - L_hat_minus_i)
    if not cfg['optim/geometric_mean']:
        # correction factor for multiplying by 1. / (kf - 1.) above;
        # to verify this, work out a 2x2 matrix of samples by hand
        local_l = local_l * k
    loss = local_l * log_q_h + w * l
    return loss / tf.to_float(b)
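

# A minimal smoke test (a sketch, not part of the original code): `cfg` is
# assumed to be a plain dict holding the single key the function reads, and
# the log-weights are stand-in random tensors with the expected
# [n_samples, batch_size] layout.
if __name__ == '__main__':
    n_samples, batch_size = 5, 3
    cfg = {'optim/geometric_mean': True}
    l = tf.random_normal([n_samples, batch_size])
    log_q_h = tf.random_normal([n_samples, batch_size])
    loss = build_vimco_loss(cfg, l, log_q_h)
    with tf.Session() as sess:
        print(sess.run(loss).shape)  # (5, 3): one term per sample and datum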