vae_conv.py source code

python

Project: zhusuan  Author: thu-ml
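# Assumed imports (not shown in the original snippet). The keyword arguments
# used below match tf.contrib.layers, which requires TensorFlow 1.x.
import tensorflow as tf
from tensorflow.contrib import layers
import zhusuan as zs


def vae_conv(observed, n, n_x, n_z, n_particles, is_training):
    # Generative model: z ~ N(0, I), x ~ Bernoulli(logits=decoder(z)).
    # Note: n_x is not used in the body; the output size is set by the layers.
    with zs.BayesianNet(observed=observed) as model:
        # Batch-norm settings shared by the hidden transposed-conv layers.
        normalizer_params = {'is_training': is_training,
                             'updates_collections': None}
        # Standard normal prior; z has shape [n_particles, n, n_z].
        z_mean = tf.zeros([n, n_z])
        z = zs.Normal('z', z_mean, std=1., n_samples=n_particles,
                      group_ndims=1)
        # Merge the particle and batch axes; each z becomes a 1x1 feature map
        # with n_z channels.
        lx_z = tf.reshape(z, [-1, 1, 1, n_z])
        # 1x1 -> 3x3
        lx_z = layers.conv2d_transpose(
            lx_z, 128, kernel_size=3, padding='VALID',
            normalizer_fn=layers.batch_norm,
            normalizer_params=normalizer_params)
        # 3x3 -> 7x7
        lx_z = layers.conv2d_transpose(
            lx_z, 64, kernel_size=5, padding='VALID',
            normalizer_fn=layers.batch_norm,
            normalizer_params=normalizer_params)
        # 7x7 -> 14x14 (default padding is 'SAME')
        lx_z = layers.conv2d_transpose(
            lx_z, 32, kernel_size=5, stride=2,
            normalizer_fn=layers.batch_norm,
            normalizer_params=normalizer_params)
        # 14x14 -> 28x28, one channel, no activation: raw Bernoulli logits.
        lx_z = layers.conv2d_transpose(
            lx_z, 1, kernel_size=5, stride=2,
            activation_fn=None)
        # Restore the particle axis and flatten each image to a vector.
        x_logits = tf.reshape(lx_z, [n_particles, n, -1])
        # Registers the node 'x' in the model; group_ndims=1 sums the
        # log-probability over the pixel axis.
        x = zs.Bernoulli('x', x_logits, group_ndims=1)
    return model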
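
The decoder above never uses n_x, so the length of the flattened output depends only on the four conv2d_transpose layers. The sketch below traces the spatial size through them in plain Python, using TensorFlow's output-size convention for transposed convolutions. The helper names (deconv_output_size, LAYERS, trace_sizes) are hypothetical and do not come from the zhusuan source. That the final 28 * 28 = 784 values are meant to equal n_x is an assumption; the snippet does not say so.

```python
def deconv_output_size(size, kernel, stride=1, padding='SAME'):
    """Spatial output size of a transposed convolution (TensorFlow convention)."""
    if padding == 'SAME':
        return size * stride
    if padding == 'VALID':
        return size * stride + max(kernel - stride, 0)
    raise ValueError('unknown padding: %r' % padding)


# (kernel, stride, padding) for the four layers in vae_conv, in order.
# Layers that do not pass padding/stride fall back to 'SAME' and stride 1.
LAYERS = [(3, 1, 'VALID'), (5, 1, 'VALID'), (5, 2, 'SAME'), (5, 2, 'SAME')]


def trace_sizes(start=1):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [start]
    for kernel, stride, padding in LAYERS:
        sizes.append(deconv_output_size(sizes[-1], kernel, stride, padding))
    return sizes


sizes = trace_sizes()
print(sizes)           # [1, 3, 7, 14, 28]
print(sizes[-1] ** 2)  # 784 values per sample after flattening
```

If the observed data has a different dimensionality, the final tf.reshape still succeeds, but the Bernoulli logits will no longer line up with the observations, so checking this arithmetic before changing kernel sizes or strides is worthwhile.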