dist_fixture.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def get_scipy_batch_logpdf(self, idx):
        """Evaluate the SciPy reference log-probability of each test datum.

        For the fixture example selected by ``idx``, the distribution
        parameters are broadcast to the shape reported by
        ``self.pyro_dist.shape`` and then sliced along the first axis, so the
        ``i``-th datum is scored against the ``i``-th slice of every parameter.

        Args:
            idx: Example selector; forwarded unchanged to
                ``self.get_dist_params`` and ``self.get_test_data``.

        Returns:
            ``None`` when ``self.scipy_arg_fn`` is falsy. Otherwise a list with
            one entry per element of the unwrapped test data, each being the
            result of ``self.scipy_dist.logpmf`` (when ``self.is_discrete`` is
            truthy) or ``self.scipy_dist.logpdf`` for that datum.
        """
        if not self.scipy_arg_fn:
            return None
        dist_params = self.get_dist_params(idx, wrap_tensor=False)
        dist_params_wrapped = self.get_dist_params(idx)
        # NOTE(review): going by the helper's name this presumably turns
        # logit-style parameters into probabilities -- confirm at its definition.
        dist_params = self._convert_logits_to_ps(dist_params)
        test_data = self.get_test_data(idx, wrap_tensor=False)
        test_data_wrapped = self.get_test_data(idx)
        shape = self.pyro_dist.shape(test_data_wrapped, **dist_params_wrapped)

        # The broadcast does not depend on the datum, so do it once per
        # parameter instead of once per (datum, parameter) pair.
        broadcast_params = {
            k: np.broadcast_to(dist_params[k], shape) for k in dist_params
        }
        # Discrete and continuous cases differ only in which SciPy method is
        # called, so pick it once up front.
        if self.is_discrete:
            log_prob_fn = self.scipy_dist.logpmf
        else:
            log_prob_fn = self.scipy_dist.logpdf

        batch_log_pdf = []
        for i, datum in enumerate(test_data):
            batch_params = {k: param[i] for k, param in broadcast_params.items()}
            # scipy_arg_fn maps the fixture's parameter names onto the
            # positional/keyword arguments passed to the SciPy method.
            args, kwargs = self.scipy_arg_fn(**batch_params)
            batch_log_pdf.append(log_prob_fn(datum, *args, **kwargs))
        return batch_log_pdf
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号