def get_scipy_batch_logpdf(self, idx):
    """Compute per-datum SciPy log densities for the test case at ``idx``.

    Each parameter returned by ``self.get_dist_params`` is broadcast to the
    shape reported by ``self.pyro_dist.shape`` and then indexed along its
    first axis, so that data point ``i`` is scored against parameter slice
    ``i``.

    Args:
        idx: Test-case index, forwarded unchanged to
            ``self.get_dist_params`` and ``self.get_test_data``.

    Returns:
        A list with one entry per element of the unwrapped test data, each
        being ``self.scipy_dist.logpmf(...)`` when ``self.is_discrete`` is
        truthy and ``self.scipy_dist.logpdf(...)`` otherwise. Returns
        ``None`` when ``self.scipy_arg_fn`` is falsy.
    """
    if not self.scipy_arg_fn:
        return None

    # The wrapped variants are only used to ask the Pyro distribution for
    # the target shape; the unwrapped ones feed NumPy/SciPy directly.
    dist_params = self.get_dist_params(idx, wrap_tensor=False)
    dist_params_wrapped = self.get_dist_params(idx)
    dist_params = self._convert_logits_to_ps(dist_params)
    test_data = self.get_test_data(idx, wrap_tensor=False)
    test_data_wrapped = self.get_test_data(idx)
    shape = self.pyro_dist.shape(test_data_wrapped, **dist_params_wrapped)

    # Nothing to score: return before broadcasting so that empty input
    # yields an empty list without touching the parameters.
    if len(test_data) == 0:
        return []

    # Broadcasting does not depend on the data point, so do it once per
    # parameter instead of once per (data point, parameter) pair.
    broadcast_params = {
        name: np.broadcast_to(value, shape)
        for name, value in dist_params.items()
    }

    # Discrete distributions expose logpmf; continuous ones expose logpdf.
    if self.is_discrete:
        log_density = self.scipy_dist.logpmf
    else:
        log_density = self.scipy_dist.logpdf

    batch_log_pdf = []
    for i, datum in enumerate(test_data):
        batch_params = {
            name: param[i] for name, param in broadcast_params.items()
        }
        args, kwargs = self.scipy_arg_fn(**batch_params)
        batch_log_pdf.append(log_density(datum, *args, **kwargs))
    return batch_log_pdf
评论列表
文章目录