def probabilities(self, Y):
    """Probability densities of parameter vectors Y under this policy.

    Evaluates the multivariate normal density with mean ``self.mean`` and
    covariance ``self.Sigma`` at each row of `Y`.

    Parameters
    ----------
    Y : array-like, shape (num_samples, n_weights)
        2d array of parameter vectors

    Returns
    -------
    resp : array, shape (num_samples,)
        the probability densities of the samples under this policy; always
        a 1d array, even when only a single sample is given

    Raises
    ------
    ImportError
        If the installed SciPy does not provide
        ``scipy.stats.multivariate_normal`` (added in SciPy 0.14).
    """
    import numpy as np

    # EAFP: multivariate_normal only exists in SciPy >= 0.14, so a failed
    # import is exactly the "SciPy too old" condition.
    try:
        from scipy.stats import multivariate_normal
    except ImportError:
        raise ImportError(
            "SciPy >= 0.14 is required for "
            "'scipy.stats.multivariate_normal'.")

    density = multivariate_normal(mean=self.mean, cov=self.Sigma).pdf(Y)
    # SciPy squeezes the pdf output, so a single sample would come back as a
    # 0d scalar; keep the documented (num_samples,) shape.
    return np.atleast_1d(density)
# (web page residue: "Comment list")
# (web page residue: "Table of contents")