def get_samples(self, n):
    """Sample the GMM distribution restricted to [min_limit, max_limit].

    A component index is drawn with probability proportional to
    ``self.weights`` (over ``range(self.N)``).  A normal variate is then drawn
    with mean ``self.points[i]`` and standard deviation ``self.sigma[i]``.
    A draw is kept only if it lies inside ``[self.min_limit, self.max_limit]``;
    otherwise it is discarded and redrawn (rejection sampling).  Candidates
    are generated in vectorised batches from NumPy's global random state,
    so ``np.random.seed`` makes results reproducible.

    Arguments
    ---------
    n : int
        Number of samples needed
    Returns
    -------
    1D array
        ``n`` samples from the distribution, each within the limits
    Raises
    ------
    ValueError
        If the interval ``[min_limit, max_limit]`` has (numerically) zero
        probability under the mixture.  Rejection sampling could never
        finish in that case.
    """
    weights = np.asarray(self.weights, dtype=float)
    normalized_w = weights / np.sum(weights)
    # Build the component-index distribution once, not once per batch.
    component = st.rv_discrete(values=(np.arange(self.N), normalized_w))
    loc = np.asarray(self.points, dtype=float)
    scale = np.asarray(self.sigma, dtype=float)
    lo, hi = self.min_limit, self.max_limit

    samples = np.zeros(n)
    if n == 0:
        return samples

    # Probability that a single raw draw lands inside [lo, hi].  Components
    # with sigma == 0 are point masses at their mean.  scipy's CDF does not
    # accept scale <= 0, so substitute 1 for them and patch their mass after.
    point_mass = scale == 0
    safe_scale = np.where(point_mass, 1.0, scale)
    mass = (st.norm.cdf(hi, loc=loc, scale=safe_scale)
            - st.norm.cdf(lo, loc=loc, scale=safe_scale))
    mass = np.where(point_mass, (loc >= lo) & (loc <= hi), mass)
    if np.dot(normalized_w, mass) <= 0:
        # Without this check the rejection loop below would never terminate.
        raise ValueError(
            "sampling window [min_limit, max_limit] has zero probability "
            "under the mixture")

    filled = 0
    while filled < n:
        idx = component.rvs(size=n)
        draws = np.random.normal(loc=loc[idx], scale=scale[idx])
        # Same acceptance rule as a per-sample check: reject only values
        # strictly above max_limit or strictly below min_limit.
        kept = draws[~((draws > hi) | (draws < lo))]
        take = min(kept.size, n - filled)
        samples[filled:filled + take] = kept[:take]
        filled += take
    return samples
评论列表
文章目录