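import tensorflow as tf
import tensorflow.contrib.slim as slim
# gumbel_softmax(), FLAGS and bernoulli are defined elsewhere in
# vae_gumbel_softmax.py; bernoulli is presumably a Bernoulli distribution
# class such as tf.contrib.distributions.Bernoulli (assumption).


def decoder(tau, logits_y):
    # Relaxed one-hot samples: (batch_size, num_cat_dists, num_classes)
    y = tf.reshape(gumbel_softmax(logits_y, tau, hard=False),
                   [-1, FLAGS.num_cat_dists, FLAGS.num_classes])
    # Generative model p(x|y), i.e. the decoder; flattened y has
    # shape=(batch_size, 200). Hidden layers use slim's default ReLU.
    net = slim.stack(slim.flatten(y),
                     slim.fully_connected,
                     [256, 512])
    # Linear output layer: Bernoulli logits, shape=(batch_size, 784)
    logits_x = slim.fully_connected(net,
                                    784,
                                    activation_fn=None)
    p_x = bernoulli(logits=logits_x)
    return p_x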
vae_gumbel_softmax.py source file (Python)
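The decoder calls gumbel_softmax(logits_y, tau, hard=False), which is defined elsewhere in vae_gumbel_softmax.py and not shown here. As a rough sketch, assuming it follows the standard Gumbel-Softmax relaxation (add Gumbel(0, 1) noise to the logits, divide by the temperature tau, take a softmax over the class axis), a NumPy version might look like the following. The name gumbel_softmax_np and the rng parameter are hypothetical, and the hard=True branch here only reproduces the one-hot forward value, not the straight-through gradient a TensorFlow version would need.

```python
import numpy as np


def gumbel_softmax_np(logits, tau, hard=False, rng=None):
    """Hypothetical sketch: relaxed one-hot sample from Gumbel-Softmax(logits, tau).

    The softmax is taken over the last (class) axis.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Gumbel(0, 1) noise, one draw per logit
    g = rng.gumbel(loc=0.0, scale=1.0, size=logits.shape)
    # Temperature-scaled softmax, shifted by the max for numerical stability
    z = (logits + g) / tau
    z = z - z.max(axis=-1, keepdims=True)
    y = np.exp(z)
    y = y / y.sum(axis=-1, keepdims=True)
    if hard:
        # Forward value only: replace each sample with its one-hot argmax
        y = np.eye(logits.shape[-1])[y.argmax(axis=-1)]
    return y


rng = np.random.default_rng(0)
# Illustrative sizes: (batch, num_cat_dists, num_classes) = (2, 20, 10)
logits = np.zeros((2, 20, 10))
y = gumbel_softmax_np(logits, tau=0.5, rng=rng)
print(y.shape)                         # (2, 20, 10)
print(np.allclose(y.sum(axis=-1), 1))  # True
```

Lowering tau pushes each sample closer to a one-hot vector at the cost of higher-variance gradients; passing hard=False, as the decoder does, keeps the sample continuous so it can be fed straight into the fully connected layers.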
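The decoder passes the flattened sample through two fully connected layers of 256 and 512 units (slim.fully_connected defaults to a ReLU activation) and a linear layer producing 784 logits, presumably one per pixel of a 28×28 binarized image, which parameterize the Bernoulli p(x|y). The NumPy sketch below mirrors those shapes with randomly initialized weights and shows how the reconstruction term log p(x|y) would be computed from logits_x. The helpers init_dense, decoder_logits_np and bernoulli_log_prob, the Glorot-style initialization, the 20×10 categorical sizes and the fake inputs are all illustrative assumptions, not code from the original file.

```python
import numpy as np


def init_dense(rng, n_in, n_out):
    # Hypothetical Glorot-style uniform init (slim uses its own initializers)
    limit = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-limit, limit, size=(n_in, n_out)), np.zeros(n_out)


def decoder_logits_np(y, params):
    """Flatten (batch, N, K) samples, apply 256 -> 512 (ReLU) -> 784 (linear)."""
    h = y.reshape(y.shape[0], -1)        # like slim.flatten: (batch, N*K)
    for w, b in params[:-1]:
        h = np.maximum(h @ w + b, 0.0)   # hidden layers use ReLU
    w, b = params[-1]
    return h @ w + b                     # logits_x, no activation: (batch, 784)


def bernoulli_log_prob(x, logits):
    """Sum over pixels of log Bernoulli(x | sigmoid(logits)), computed stably."""
    # log p = x * logits - softplus(logits); logaddexp(0, l) == log(1 + e^l)
    return np.sum(x * logits - np.logaddexp(0.0, logits), axis=-1)


rng = np.random.default_rng(0)
N, K = 20, 10                            # illustrative sizes (N*K = 200)
sizes = [N * K, 256, 512, 784]
params = [init_dense(rng, a, b) for a, b in zip(sizes[:-1], sizes[1:])]

# Stand-ins: points on the simplex for relaxed samples, random binary "images"
y = rng.dirichlet(np.ones(K), size=(8, N))
x = (rng.uniform(size=(8, 784)) > 0.5).astype(np.float64)
logits_x = decoder_logits_np(y, params)
print(logits_x.shape)                         # (8, 784)
print(bernoulli_log_prob(x, logits_x).shape)  # (8,)
```

Working in logits rather than probabilities avoids computing log(sigmoid(l)) directly, which underflows for large negative l; a distribution object built with logits=logits_x, as in the decoder, is likewise parameterized in logit space.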