def sample(params, eps, dist='gauss'):
    """Draw a reparameterised sample from a distribution, given external noise.

    The distribution family is selected by substring matching on ``dist``:

    * ``'gauss'`` -- ``params`` is ``(mean, cov)``; the result is
      ``mean + sqrt(cov) * eps``, i.e. ``cov`` is used element-wise as a
      variance.
    * ``'gm'`` -- Gaussian mixture; ``params`` is ``(means, covs, pi_logits)``.
      One mixture component is drawn per row from ``pi_logits``, and that
      component's mean/variance are then used as in the Gaussian case.
    * ``'bin'`` may be combined with either family (e.g. ``'gauss_bin'``).
      The last entry of ``params`` is then a logits tensor whose sigmoid is
      concatenated to the sample along axis 1.

    If both ``'gauss'`` and ``'gm'`` occur in ``dist``, ``'gauss'`` is used.

    Args:
        params: sequence of tensors laid out as described above.
        eps: noise tensor, combined element-wise with the selected mean and
            variance (presumably standard-normal; not checked here).
        dist: distribution selector string.

    Returns:
        The sample tensor.

    Raises:
        NotImplementedError: if ``dist`` contains neither ``'gauss'`` nor
            ``'gm'``.
    """
    if 'bin' in dist:
        # The Bernoulli logits ride along as the final entry of ``params``.
        logits = params[-1]
        params = params[:-1]
    if 'gauss' in dist:
        mean, cov = params
        s = mean + tf.sqrt(cov) * eps
    elif 'gm' in dist:
        means, covs, pi_logits = params
        # One component index per row of ``pi_logits``: shape (batch, 1), int64.
        choices = tf.multinomial(pi_logits, num_samples=1)
        # Use the *dynamic* batch size. The static ``get_shape()[0]`` is None
        # when the batch dimension is unknown at graph-construction time, which
        # made ``range(batch_size)`` fail.
        batch_size = tf.shape(choices, out_type=tf.int64)[0]
        ids = tf.expand_dims(tf.range(batch_size, dtype=tf.int64), axis=1)
        # Index rows ``[i, choice_i]`` pick component ``choice_i`` for row ``i``.
        idx_tensor = tf.concat([ids, choices], axis=1)
        chosen_means = tf.gather_nd(means, idx_tensor)
        chosen_covs = tf.gather_nd(covs, idx_tensor)
        s = chosen_means + tf.sqrt(chosen_covs) * eps
    else:
        raise NotImplementedError(
            "sample: unsupported distribution {!r}".format(dist))
    if 'bin' in dist:
        # NOTE(review): this appends sigmoid probabilities, not Bernoulli
        # draws; ``eps`` is not used for this part. Confirm this is intended.
        sig = tf.sigmoid(logits)
        s = tf.concat([s, sig], axis=1)
    return s
评论列表
文章目录