def vae(observed, n, n_x, n_z, n_k, tau, n_particles, relaxed=False):
    """Build the generative network of a VAE with categorical latent codes.

    The latent ``z`` consists of ``n_z`` categorical variables per example,
    each over ``n_k`` categories, with all-zero logits (i.e. a uniform prior).
    ``z`` is flattened and decoded by two tanh dense layers of width 200 into
    Bernoulli logits for an ``n_x``-dimensional observation ``x``.

    Args:
        observed: Passed to ``zs.BayesianNet(observed=...)``; presumably a
            mapping from node name ('z', 'x') to observed tensors — confirm
            against callers.
        n: Leading batch dimension of the prior logits (assumed batch size).
        n_x: Number of output logits, i.e. the dimensionality of ``x``.
        n_z: Number of categorical latent variables per example.
        n_k: Number of categories of each latent variable.
        tau: Temperature of the ``ExpConcrete`` relaxation (used only when
            ``relaxed`` is true).
        n_particles: Number of samples of ``z`` drawn (``n_samples``).
        relaxed: If true, ``z`` uses the continuous ExpConcrete relaxation
            (exponentiated back to the simplex). Otherwise ``z`` is an exact
            one-hot categorical sample cast to float32.

    Returns:
        The ``zs.BayesianNet`` built inside the context, containing nodes
        'z' and 'x'.
    """
    # Both branches flatten the (n_particles, n, n_z, n_k) sample to this shape.
    flat_shape = [n_particles, n, n_z * n_k]
    with zs.BayesianNet(observed=observed) as model:
        prior_logits = tf.zeros([n, n_z, n_k])
        if not relaxed:
            onehot_z = zs.OnehotCategorical(
                'z', prior_logits, n_samples=n_particles, group_ndims=1,
                dtype=tf.float32)
            z_flat = tf.reshape(onehot_z, flat_shape)
        else:
            # ExpConcrete samples live in log-space; exp maps them back
            # onto the probability simplex before decoding.
            log_z = zs.ExpConcrete('z', tau, prior_logits,
                                   n_samples=n_particles, group_ndims=1)
            z_flat = tf.exp(tf.reshape(log_z, flat_shape))

        # Decoder: two hidden tanh layers, created in the same order as before.
        hidden = z_flat
        for _ in range(2):
            hidden = tf.layers.dense(hidden, 200, activation=tf.tanh)
        x_logits = tf.layers.dense(hidden, n_x)
        # The node is created inside the BayesianNet context; the returned
        # object itself is not needed locally.
        zs.Bernoulli('x', x_logits, group_ndims=1)
    return model
评论列表
文章目录