def ar_layer(z0, hps, n_hidden=10):
    """Old IAF layer: predict autoregressive (mu, log_sigma) per latent dimension.

    For each position i in ``0 .. hps.z_size - 1`` the input ``z0`` is copied
    and masked so that only the columns ``j < i`` are visible. A small two-layer
    MLP, shared across all positions, then maps each masked copy to a pair
    ``(mu_i, log_sigma_i)``. Position 0 therefore always sees an all-zero
    input, so its outputs depend only on the layer biases.

    Args:
        z0: rank-2 tensor of shape ``[batch, hps.z_size]``. The mask built
            below has ``batch * hps.z_size`` rows, so the second dimension must
            equal ``hps.z_size`` for the element-wise product to line up.
        hps: hyper-parameter object; only ``hps.z_size`` is read.
        n_hidden: width of the hidden ReLU layer.

    Returns:
        ``(mu, log_sigma)``, each a tensor of shape ``[batch, hps.z_size]``.

    Note:
        The input is replicated ``hps.z_size`` times, so memory scales as
        ``batch * hps.z_size ** 2``.
        NOTE(review): each call creates fresh slim layers unless the caller
        wraps it in a reusing variable scope -- confirm this is intended.
    """
    z_size = hps.z_size
    batch = tf.shape(z0)[0]

    # Repeat input: shape [batch * z_size, z_size]; row b * z_size + i is a
    # copy of z0[b].
    z_rep = tf.reshape(tf.tile(z0, [1, z_size]), [-1, z_size])

    # Strictly lower-triangular mask: sequence_mask(range(z_size)) makes row i
    # True only for columns j < i. Tile it once per batch element so it lines
    # up with z_rep.
    mask = tf.sequence_mask(tf.range(z_size), z_size)[None, :, :]
    mask = tf.reshape(tf.tile(mask, [batch, 1, 1]), [-1, z_size])

    # Cast to the input's own dtype. The previous tf.to_float was deprecated
    # and forced float32, which broke non-float32 inputs.
    z_mask = z_rep * tf.cast(mask, z_rep.dtype)

    # Predict mu and log_sigma with one MLP applied to every masked copy.
    mid = slim.fully_connected(z_mask, n_hidden, activation_fn=tf.nn.relu)
    pars = slim.fully_connected(mid, 2, activation_fn=None)

    # [batch * z_size, 2] -> [batch, z_size, 2] -> two [batch, z_size] tensors.
    pars = tf.reshape(pars, [-1, z_size, 2])
    mu, log_sigma = tf.unstack(pars, axis=2)
    return mu, log_sigma
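# 评论列表 ("comment list") -- leftover navigation text from the web page this snippet was copied from.
# 文章目录 ("table of contents") -- leftover navigation text from the same page.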