e2c_seq.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:e2c-pytorch 作者: ethanluoyc 项目源码 文件源码
def sampleQ_psi(z,u,Q_phi,share=None):
  A,B,o,v,r=transition(z,share)
  with tf.variable_scope("sampleQ_psi"):
    mu_t=tf.expand_dims(Q_phi.mu,-1) # batch,z_dim,1
    Amu=tf.squeeze(tf.batch_matmul(A,mu_t), [-1])
    u=tf.expand_dims(u,-1) # batch,u_dim,1
    Bu=tf.squeeze(tf.batch_matmul(B,u),[-1])
    Q_psi=NormalDistribution(Amu+Bu+o,Q_phi.sigma,Q_phi.logsigma, v, r)
    # the actual z_next sample is generated by deterministically transforming z_t
    z=tf.expand_dims(z,-1)
    Az=tf.squeeze(tf.batch_matmul(A,z),[-1])
    z_next=Az+Bu+o
    return z_next,Q_psi#,(A,B,o,v,r) # debugging
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号