def sampleQ_psi(z, u, Q_phi, share=None):
    """Build the transition distribution Q_psi and the propagated next latent.

    Calls ``transition(z, share)`` to obtain ``A, B, o, v, r`` and uses that
    affine map in two ways:

    * ``Q_psi`` is a ``NormalDistribution`` whose mean is
      ``A @ Q_phi.mu + B @ u + o``.  It reuses ``Q_phi.sigma`` and
      ``Q_phi.logsigma`` and is additionally given ``v`` and ``r``.
    * The returned ``z_next`` is ``A @ z + B @ u + o``.  This is a
      deterministic transform of the input sample ``z``; ``Q_psi`` is not
      sampled.

    Args:
        z: latent sample; assumed (batch, z_dim) like ``Q_phi.mu`` (TODO confirm).
        u: control input, (batch, u_dim).
        Q_phi: distribution object exposing ``mu``, ``sigma`` and ``logsigma``.
        share: forwarded unchanged to ``transition``.

    Returns:
        Tuple ``(z_next, Q_psi)``.
    """

    def _apply(mat, vec):
        # Lift vec (batch, m) to a (batch, m, 1) column, batch-multiply by
        # mat, then drop the trailing singleton axis again.
        column = tf.expand_dims(vec, -1)
        return tf.squeeze(tf.batch_matmul(mat, column), [-1])

    A, B, o, v, r = transition(z, share)
    with tf.variable_scope("sampleQ_psi"):
        # Mean of Q_psi: the affine transition applied to the mean of Q_phi.
        A_mu = _apply(A, Q_phi.mu)
        B_u = _apply(B, u)
        Q_psi = NormalDistribution(
            A_mu + B_u + o, Q_phi.sigma, Q_phi.logsigma, v, r)
        # The next-latent sample pushes the sample z itself through the same
        # affine map, reusing the B @ u term computed above.
        A_z = _apply(A, z)
        z_next = A_z + B_u + o
        # For debugging, (A, B, o, v, r) can be appended to the return value.
        return z_next, Q_psi
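# NOTE(review): web-page residue from the page this code was copied from,
# not code: "Comment list" / "Table of contents".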