def sampleQ_psi(z, u, Q_phi):
    """Advance the latent state one step and build the matching distribution.

    ``transition(z)`` supplies ``A``, ``B``, ``o``, ``v`` and ``r``. The first
    three are applied as an affine map ``A x + B u + o``; ``v`` and ``r`` are
    only forwarded to ``NormalDistribution`` (their meaning is defined there,
    not in this function).

    Args:
        z: current latent sample (assumed to have the same layout as
            ``Q_phi.mu``, i.e. (batch, z_dim) -- TODO confirm against caller).
        u: control input, (batch, u_dim) before the column axis is appended.
        Q_phi: object exposing ``mu``, ``sigma`` and ``logsigma``; ``mu`` is
            (batch, z_dim) before the column axis is appended.

    Returns:
        Tuple ``(z_next, Q_psi)``: ``z_next`` is ``A z + B u + o`` and
        ``Q_psi`` is a ``NormalDistribution`` whose first argument is
        ``A mu + B u + o``.
    """

    def _batched_matvec(matrix, vector):
        # Append a trailing axis so the vector becomes a column per batch
        # element, multiply, then drop that singleton axis again.
        column = tf.expand_dims(vector, -1)
        return tf.squeeze(tf.batch_matmul(matrix, column), [-1])

    A, B, o, v, r = transition(z)

    with tf.variable_scope("sampleQ_psi"):
        A_mu = _batched_matvec(A, Q_phi.mu)
        B_u = _batched_matvec(B, u)
        predicted_mean = A_mu + B_u + o
        Q_psi = NormalDistribution(
            predicted_mean, Q_phi.sigma, Q_phi.logsigma, v, r
        )

        # z_next is not drawn from Q_psi: it is obtained by pushing the given
        # z through the same affine map deterministically.
        A_z = _batched_matvec(A, z)
        z_next = A_z + B_u + o

    # For debugging, (A, B, o, v, r) could be returned as a third element.
    return z_next, Q_psi
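# 评论列表 (comment list) -- leftover web-page navigation text, not code
# 文章目录 (table of contents) -- leftover web-page navigation text, not code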