def reparameterize(encoded, num_discrete, tau, hard=False,
                   rnd_sample=None, eps=1e-20):
    """Split ``encoded`` into Gaussian and Gumbel parts and reparameterize each.

    The first ``num_features - num_discrete`` columns of ``encoded`` are sent
    to ``gaussian_reparmeterization``. The last ``num_discrete`` columns are
    sent to ``gumbel_reparmeterization``. The two resulting samples are
    concatenated along axis 1.

    Args:
        encoded: tensor sliced as ``encoded[:, cols]``. Its static size along
            axis 1 is taken as the total number of columns. Presumably shaped
            (batch, features) -- TODO confirm rank with callers.
        num_discrete: number of trailing columns treated as Gumbel logits.
        tau: forwarded to ``gumbel_reparmeterization``. Presumably the
            Gumbel-softmax temperature -- confirm in that function.
        hard: forwarded to ``gumbel_reparmeterization``.
        rnd_sample: forwarded to ``gumbel_reparmeterization``.
        eps: accepted for interface compatibility; not used by this function.

    Returns:
        A list ``[z, z_n, z_discrete, softmax(logits_gumbel), kl_n,
        kl_discrete]``. The first four entries are passed through
        ``slim.flatten``; the KL terms are returned as produced by the
        helper functions.

    Raises:
        ValueError: if the static size of axis 1 is unknown, or if
            ``num_discrete`` is outside ``[0, num_features]``.
    """
    eshp = encoded.get_shape().as_list()
    # Single-argument print calls produce the same output on Python 2 and 3.
    print("encoded = {}".format(eshp))
    num_features = eshp[1]
    if num_features is None:
        raise ValueError(
            "reparameterize needs a statically known size for axis 1 of "
            "`encoded`; got shape {}".format(eshp))
    if not 0 <= num_discrete <= num_features:
        # A negative num_normal would make the slices below silently select
        # the wrong columns, so fail loudly instead.
        raise ValueError(
            "num_discrete must be in [0, {}], got {}".format(
                num_features, num_discrete))
    num_normal = num_features - num_discrete
    print("num_normal = {}".format(num_normal))

    # Leading columns feed the N(0, I) path; trailing columns feed Gumbel(0, 1).
    logits_normal = encoded[:, 0:num_normal]
    logits_gumbel = encoded[:, num_normal:num_features]

    # Reparameterize both parts: the Gumbel path for the discrete columns and
    # the Gaussian path for the continuous columns.
    z_discrete, kl_discrete = gumbel_reparmeterization(logits_gumbel,
                                                       tau,
                                                       rnd_sample,
                                                       hard)
    z_n, kl_n = gaussian_reparmeterization(logits_normal)

    # Concatenate the continuous and discrete samples along the feature axis
    # (no padding is applied).
    z = tf.concat([z_n, z_discrete], axis=1)
    return [slim.flatten(z),
            slim.flatten(z_n),
            slim.flatten(z_discrete),
            slim.flatten(tf.nn.softmax(logits_gumbel)),
            kl_n,
            kl_discrete]
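# Stray page-navigation labels from the scraped source page, kept as a comment
# so the module parses: "评论列表" ("Comment list"), "文章目录" ("Table of contents").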