e2c_plane.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:e2c-pytorch 作者: ethanluoyc 项目源码 文件源码
def sampleQ_psi(z, u, Q_phi):
    A, B, o, v, r = transition(z)
    with tf.variable_scope("sampleQ_psi"):
        mu_t = tf.expand_dims(Q_phi.mu, -1)  # batch,z_dim,1
        Amu = tf.squeeze(tf.batch_matmul(A, mu_t), [-1])
        u = tf.expand_dims(u, -1)  # batch,u_dim,1
        Bu = tf.squeeze(tf.batch_matmul(B, u), [-1])
        Q_psi = NormalDistribution(Amu + Bu + o, Q_phi.sigma, Q_phi.logsigma,
                                   v, r)
        # the actual z_next sample is generated by deterministically transforming z_t
        z = tf.expand_dims(z, -1)
        Az = tf.squeeze(tf.batch_matmul(A, z), [-1])
        z_next = Az + Bu + o
        return z_next, Q_psi  #,(A,B,o,v,r) # debugging
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号