plots.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目: cgpm 作者: probcomp 项目源码 文件源码
def plot_dist_discrete(X, output, clusters, ax=None, Y=None, hist=True):
    """Plot discrete data against the weighted densities of a set of clusters.

    Optionally draws the normalized histogram of ``X`` as gray bars. For each
    cluster ``k`` it then draws bars of ``N_k / len(X) * exp(logpdf(y))`` over
    the evaluation points ``Y``. Finally it outlines the sum of those weighted
    densities in thick black.

    Parameters
    ----------
    X : sequence of int
        Observed values, cast to ``int``. Must be non-empty and non-negative
        (``max`` and ``np.bincount`` require this).
    output : hashable
        Key under which each evaluation point is passed to ``logpdf``.
    clusters : mapping
        Non-empty mapping whose values expose:
        - ``.N``: used as the cluster weight ``N / len(X)``;
        - ``.logpdf(None, {output: y})``;
        - ``.name()``: used for the plot title.
    ax : matplotlib Axes, optional
        Axis to draw on. A new one is created via ``plt.subplots()`` when None.
    Y : iterable of int, optional
        Points at which the cluster densities are evaluated. Defaults to
        ``range(max(X) + 1)``.
    hist : bool, optional
        Whether to draw the empirical histogram of ``X``. Default True.

    Returns
    -------
    ax
        The axis that was drawn on.
    """
    # Create a new axis if the caller did not supply one.
    if ax is None:
        _, ax = plt.subplots()
    X = np.asarray(X, dtype=int)
    x_max = int(max(X))
    n = float(len(X))
    # The histogram always lives on 0..x_max because np.bincount(X) has exactly
    # x_max + 1 entries. Keep it separate from the density evaluation points.
    support = range(x_max + 1)
    Y = support if Y is None else list(Y)
    if hist:
        X_hist = np.bincount(X) / n
        ax.bar(support, X_hist, color='gray', edgecolor='none')
    # Log mixture weight of each cluster: log(N_k / n).
    keys = list(clusters)
    W = [log(clusters[k].N) - log(n) for k in keys]
    pdf = np.zeros((len(keys), len(Y)))
    for i, k in enumerate(keys):
        pdf[i, :] = np.exp(
            [W[i] + clusters[k].logpdf(None, {output: y}) for y in Y])
        color, alpha = gu.curve_color(i)
        ax.bar(Y, pdf[i, :], color=color, edgecolor='none', alpha=alpha)
    # Plot the sum of the weighted pdfs as a thick black outline.
    ax.bar(
        Y, np.sum(pdf, axis=0), color='none', edgecolor='black', linewidth=3)
    ax.set_xlim([0, x_max + 1])
    # dict.values() is not indexable on Python 3, so take the first value via
    # an iterator.
    ax.set_title(next(iter(clusters.values())).name())
    return ax
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号