import tensorflow as tf
from tensorflow.contrib import layers
import zhusuan as zs


def q_net(x, n_xl, n_z, n_particles, is_training):
    # Variational posterior q(z|x): a convolutional encoder that maps each
    # input image to the mean and log-std of a diagonal Gaussian over z.
    with zs.BayesianNet() as variational:
        normalizer_params = {'is_training': is_training,
                             'updates_collections': None}
        # Reshape the flat input to NHWC and apply strided convolutions
        # with batch normalization.
        lz_x = tf.reshape(tf.to_float(x), [-1, n_xl, n_xl, 1])
        lz_x = layers.conv2d(
            lz_x, 32, kernel_size=5, stride=2,
            normalizer_fn=layers.batch_norm,
            normalizer_params=normalizer_params)
        lz_x = layers.conv2d(
            lz_x, 64, kernel_size=5, stride=2,
            normalizer_fn=layers.batch_norm,
            normalizer_params=normalizer_params)
        lz_x = layers.conv2d(
            lz_x, 128, kernel_size=5, padding='VALID',
            normalizer_fn=layers.batch_norm,
            normalizer_params=normalizer_params)
        lz_x = layers.dropout(lz_x, keep_prob=0.9, is_training=is_training)
        lz_x = tf.reshape(lz_x, [-1, 128 * 3 * 3])
        # Gaussian parameters of q(z|x); n_particles samples of z are drawn.
        lz_mean = layers.fully_connected(lz_x, n_z, activation_fn=None)
        lz_logstd = layers.fully_connected(lz_x, n_z, activation_fn=None)
        z = zs.Normal('z', lz_mean, logstd=lz_logstd, n_samples=n_particles,
                      group_ndims=1)
    return variational
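
For context, here is a minimal sketch of how such a variational network is typically wired into a stochastic-gradient variational Bayes (SGVB) objective with ZhuSuan. The generative model's log_joint function, the placeholder shapes, and the hyper-parameter values below are illustrative assumptions, not part of the original snippet.

# Illustrative sizes (e.g. 28x28 MNIST-style inputs); not from the original.
n_xl, n_z, n_particles = 28, 40, 1

x_input = tf.placeholder(tf.float32, shape=[None, n_xl * n_xl], name='x')
is_training = tf.placeholder(tf.bool, shape=[], name='is_training')

# Build q(z|x) and query the sampled z together with its log-density.
variational = q_net(x_input, n_xl, n_z, n_particles, is_training)
qz_samples, log_qz = variational.query('z', outputs=True,
                                       local_log_prob=True)

# `log_joint(observed)` is assumed to return log p(x, z) under the generative
# model, which is defined elsewhere and not shown in this snippet.
lower_bound = zs.variational.elbo(log_joint,
                                  observed={'x': x_input},
                                  latent={'z': [qz_samples, log_qz]},
                                  axis=0)
cost = tf.reduce_mean(lower_bound.sgvb())
infer_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(cost)

The `axis=0` argument tells ZhuSuan to average over the n_particles sample dimension introduced by `n_samples=n_particles` in the encoder.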