model.py source code

python

Project: dgm    Author: ashwindcruz
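import chainer.functions as F


def planar_flows(self, z):
    """Apply self.num_trans planar flows to z (method of a chainer.Chain).

    Records every intermediate z in self.z_trans and the phi term of each
    flow (used for the log-det-Jacobian) in self.phi.
    """
    self.z_trans = [z]
    self.phi = []
    dim_latent = z.shape[1]

    for i in range(self.num_trans):
        # Per-flow child links registered on this Chain; their definitions are
        # not shown here (presumably elementwise scale links for w and u, and a
        # bias link for b).
        flow_w = self['flow_w_' + str(i)]
        flow_b = self['flow_b_' + str(i)]
        flow_u = self['flow_u_' + str(i)]

        # Pre-activation w^T z + b: scale z by w, sum over the latent axis, add b.
        h = F.sum(flow_w(z), axis=1)
        h = flow_b(h)
        h_tanh = F.tanh(h)

        # Tile the per-sample scalar to shape (batch, dim_latent), then scale by u.
        h = F.transpose(F.tile(h_tanh, (dim_latent, 1)))
        h = flow_u(h)

        # f(z) = z + u * tanh(w^T z + b); build a new Variable instead of
        # updating z in place, so earlier entries of self.z_trans are untouched.
        z = z + h
        self.z_trans.append(z)

        # Calculate and store the phi term: tanh'(w^T z + b) * w,
        # with tanh' = 1 - tanh^2.
        h_tanh_derivative = 1 - h_tanh * h_tanh
        h_tanh_derivative = F.transpose(F.tile(h_tanh_derivative, (dim_latent, 1)))
        phi = flow_w(h_tanh_derivative)  # Equation (11)
        self.phi.append(phi)

    return z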
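To check what one pass through the loop computes, here is a minimal NumPy sketch of a single planar flow, independent of Chainer. It assumes, as the method suggests but does not show, that `flow_w_i` and `flow_u_i` multiply elementwise by vectors `w` and `u` and that `flow_b_i` adds a scalar `b`. The helper names `planar_flow` and `numerical_log_det` and the parameter values are hypothetical. Unlike `planar_flows`, which only stores the phi terms, the sketch also turns phi into the log-determinant of the Jacobian, log|1 + uᵀψ(z)|, and checks it against a finite-difference Jacobian.

```python
import numpy as np


def planar_flow(z, w, u, b):
    """One planar flow f(z) = z + u * tanh(w^T z + b) on a batch z of shape (batch, d).

    Assumption: w, u (shape (d,)) and scalar b stand in for the parameters of the
    flow_w_i, flow_u_i and flow_b_i links (hypothetical mapping).
    Returns the transformed batch and log|det df/dz| for each sample.
    """
    a = z @ w + b                            # (batch,) pre-activation w^T z + b
    h = np.tanh(a)
    z_new = z + np.outer(h, u)               # same as tiling h to (batch, d), then scaling by u
    psi = np.outer(1.0 - h ** 2, w)          # (batch, d): the phi term, tanh'(a) * w
    log_det = np.log(np.abs(1.0 + psi @ u))  # (batch,): log|1 + u^T psi(z)|
    return z_new, log_det


def numerical_log_det(z_row, w, u, b, eps=1e-6):
    """Log|det| of a central-difference Jacobian of f at one point, for checking."""
    d = z_row.shape[0]
    jac = np.empty((d, d))
    for j in range(d):
        step = np.zeros(d)
        step[j] = eps
        f_plus, _ = planar_flow((z_row + step)[None, :], w, u, b)
        f_minus, _ = planar_flow((z_row - step)[None, :], w, u, b)
        jac[:, j] = (f_plus[0] - f_minus[0]) / (2 * eps)  # column j = df/dz_j
    return np.log(np.abs(np.linalg.det(jac)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    z = rng.normal(size=(4, 3))
    # Hypothetical parameters with u^T w > 0, so 1 + u^T psi(z) stays above 1.
    w = np.array([0.5, -0.3, 0.8])
    u = np.array([0.4, 0.1, 0.6])
    b = 0.1
    z_new, log_det = planar_flow(z, w, u, b)
    print(np.allclose(log_det[0], numerical_log_det(z[0], w, u, b), atol=1e-6))  # True
```

Because the Jacobian of one flow is I + u ψ(z)ᵀ, the matrix determinant lemma reduces its determinant to the scalar 1 + uᵀψ(z), which is why storing phi per flow is enough. For a stack of K flows, the per-flow log_det values are summed and subtracted from log q_0(z_0) to get log q_K(z_K).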