def plot_dist_discrete(X, output, clusters, ax=None, Y=None, hist=True):
    """Plot integer data's empirical histogram against weighted cluster pmfs.

    Draws three layers on ``ax`` over the support ``0..max(X)``:

    * gray bars: empirical frequencies of ``X`` (``np.bincount(X) / len(X)``);
    * one translucent bar series per cluster, with height
      ``(N_k / len(X)) * exp(logpdf(None, {output: y}))``;
    * black outlined bars: the sum of those weighted per-cluster values.

    Parameters
    ----------
    X : array-like
        Observations, cast to ``int``. Must be non-empty and non-negative
        (``max`` and ``np.bincount`` require this).
    output : hashable
        Key of the query dict ``{output: y}`` passed to each cluster's
        ``logpdf``.
    clusters : dict
        Non-empty mapping from cluster id to a component object. Each
        component must expose ``N``, ``logpdf(None, query)`` and ``name()``.
    ax : matplotlib Axes, optional
        Axis to draw on. A new figure and axis are created when ``None``.
    Y : ignored
        Accepted for signature compatibility only. The support is always
        ``0..max(X)`` so the bars line up with ``np.bincount(X)``.
    hist : ignored
        Accepted for signature compatibility only. The empirical histogram
        is always drawn.

    Returns
    -------
    The axis that was drawn on.
    """
    # Create a new axis if the caller did not supply one.
    if ax is None:
        _, ax = plt.subplots()

    # Set up the x axis: integer support 0..max(X), matching bincount's length.
    X = np.asarray(X, dtype=int)
    n = float(len(X))
    x_max = int(max(X))
    Y = range(x_max + 1)

    # Empirical frequencies of each integer value.
    X_hist = np.bincount(X) / n
    ax.bar(Y, X_hist, color='gray', edgecolor='none')

    # Log-weight of each cluster: log(N_k / n), kept in log space and added to
    # the cluster's log-density before exponentiating.
    W = [log(clusters[k].N) - log(n) for k in clusters]

    # Weighted per-cluster pmfs, one row per cluster (same iteration order as W).
    pdf = np.zeros((len(clusters), len(Y)))
    for i, k in enumerate(clusters):
        pdf[i, :] = np.exp(
            [W[i] + clusters[k].logpdf(None, {output: y}) for y in Y])
        color, alpha = gu.curve_color(i)
        ax.bar(Y, pdf[i, :], color=color, edgecolor='none', alpha=alpha)

    # Outline the sum of the weighted per-cluster pmfs.
    ax.bar(
        Y, np.sum(pdf, axis=0), color='none', edgecolor='black', linewidth=3)
    ax.set_xlim([0, x_max + 1])

    # Title from the first cluster. dict.values() is a non-indexable view on
    # Python 3, so take the first element via an iterator.
    ax.set_title(next(iter(clusters.values())).name())
    return ax
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")