test_iaf.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def _test_jacobian(self, input_dim, hidden_dim):
        """Numerically check the Jacobian of an InverseAutoregressiveFlow.

        A forward finite-difference Jacobian of the flow is built at a random
        point, its rows and columns are reordered with the permutation
        reported by the flow's autoregressive network, and three properties
        are asserted:

        * the sum of the logs of the permuted Jacobian's diagonal agrees with
          ``iaf.log_det_jacobian`` to within ``self.epsilon``;
        * every diagonal entry of the permuted Jacobian is nonzero;
        * every entry strictly below the diagonal is exactly zero.

        Args:
            input_dim: Dimensionality of the flow's input; the Jacobian is
                ``input_dim x input_dim``.
            hidden_dim: Hidden size forwarded to ``InverseAutoregressiveFlow``.

        Note:
            ``self.epsilon`` is used both as the finite-difference step and as
            the tolerance on the log-determinant; it is defined outside this
            method (presumably in the test case's ``setUp`` -- confirm).
        """
        jacobian = torch.zeros(input_dim, input_dim)
        iaf = InverseAutoregressiveFlow(input_dim, hidden_dim, sigmoid_bias=0.5)

        def nonzero(x):
            # Entry-wise indicator: 1.0 where x is nonzero, 0.0 where it is zero.
            return torch.sign(torch.abs(x))

        x = Variable(torch.randn(1, input_dim))
        iaf_x = iaf(x)
        for j in range(input_dim):
            # The perturbation and the perturbed forward pass depend only on
            # the input coordinate j, so they are evaluated once per j rather
            # than once per (j, k) pair.
            epsilon_vector = torch.zeros(1, input_dim)
            epsilon_vector[0, j] = self.epsilon
            iaf_x_eps = iaf(x + Variable(epsilon_vector))
            delta = (iaf_x_eps - iaf_x) / self.epsilon
            # Row 0 of the (1, input_dim) difference quotient, i.e. the values
            # delta[0, k] for every output coordinate k.
            delta_row = delta.data.cpu().numpy()[0]
            for k in range(input_dim):
                # jacobian[j, k]: response of output k to a nudge of input j.
                jacobian[j, k] = float(delta_row[k])

        # Re-index rows and columns by the network's permutation before the
        # triangular-structure checks below.
        permutation = iaf.get_arn().get_permutation()
        permuted_jacobian = jacobian.clone()
        for j in range(input_dim):
            for k in range(input_dim):
                permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]

        analytic_ldt = iaf.log_det_jacobian(iaf_x).data.cpu().numpy()[0]
        # Numeric log-determinant taken as the sum of log-diagonal entries.
        numeric_ldt = torch.sum(torch.log(torch.diag(permuted_jacobian)))
        ldt_discrepancy = np.fabs(analytic_ldt - numeric_ldt)

        # Sparsity pattern of the permuted Jacobian, computed once and reused.
        structure = nonzero(permuted_jacobian)
        diag_sum = torch.sum(torch.diag(structure))
        lower_sum = torch.sum(torch.tril(structure, diagonal=-1))

        self.assertTrue(ldt_discrepancy < self.epsilon)
        self.assertTrue(diag_sum == float(input_dim))
        self.assertTrue(lower_sum == float(0.0))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号