def fit(samples, is_continuous):
    '''
    Fits a distribution to the given samples and returns the best candidate
    according to the Akaike information criterion (AIC).

    Parameters
    ----------
    samples : array_like
        Array of samples. Must contain at least one value.
    is_continuous : bool
        If `True` then a continuous distribution is fitted. Otherwise, a
        discrete distribution is fitted.

    Returns
    -------
    best_marginal : Marginal
        The candidate distribution (wrapped in `Marginal`) with the lowest
        AIC on `samples`.

    Raises
    ------
    ValueError
        If `samples` is empty.

    Notes
    -----
    Continuous candidates are `norm`, plus `gamma` when every sample is
    strictly positive; their parameters come from the distribution's own
    ``fit`` method. Discrete candidates are `poisson`, `binom` (only when
    ``max(samples) > 0``) and `nbinom` (only when the sample variance exceeds
    the sample mean); their parameters are moment-style estimates.
    '''
    # Convert first so plain lists/tuples work: `samples <= 0` below is only
    # an elementwise comparison for arrays.
    samples = np.asarray(samples)
    if samples.size == 0:
        raise ValueError("samples must contain at least one value")

    # Mean and variance
    mean = np.mean(samples)
    var = np.var(samples)

    # Set suitable distributions (order matters: ties in AIC go to the
    # earliest candidate).
    if is_continuous:
        options = [norm]
        if not np.any(samples <= 0):
            options.append(gamma)
    else:
        options = [poisson]
        # binom's n is estimated as max(samples); with n == 0 the success
        # probability below would be 0/0 = nan.
        if np.max(samples) > 0:
            options.append(binom)
        # The nbinom estimate n = mean^2 / (var - mean) is only positive
        # when var > mean.
        if var > mean:
            options.append(nbinom)

    # Fit parameters and construct marginals
    params = []
    marginals = []
    for dist in options:
        if dist is poisson:
            dist_params = [mean]
        elif dist is binom:
            param_n = np.max(samples)
            param_p = np.sum(samples) / (param_n * samples.size)
            dist_params = [param_n, param_p]
        elif dist is nbinom:
            # Matches scipy's nbinom: mean = n(1-p)/p, var = n(1-p)/p^2.
            param_n = mean * mean / (var - mean)
            param_p = mean / var
            dist_params = [param_n, param_p]
        else:
            # Continuous candidates use scipy's own fitting routine.
            dist_params = dist.fit(samples)
        params.append(dist_params)
        marginals.append(Marginal(dist(*dist_params)))

    # Akaike information criterion: AIC = 2k - 2 * log-likelihood, where the
    # log-likelihood is summed from Marginal.logpdf. Lower is better.
    aic = np.array([
        2 * len(dist_params) - 2 * np.sum(marginal.logpdf(samples))
        for dist_params, marginal in zip(params, marginals)
    ], dtype=float)
    # np.argmin returns the index of the first NaN if any are present, so a
    # failed fit would otherwise "win"; rank such candidates last instead.
    aic[np.isnan(aic)] = np.inf
    best_marginal = marginals[int(np.argmin(aic))]
    return best_marginal
评论列表
文章目录