TestSurrogateBound.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:bnpy 作者: bnpy 项目源码 文件源码
def plotBoundVsK(KVals=np.arange(1,50),
                 alpha=0.5,
                 gamma=10,
                 labels=None,
                 betaFunc='prior'):
    if labels is None:
        txtlabel = str(alpha)
        labels = [None, None]
    else:
        txtlabel = 'alpha\n' + str(alpha)
    exactVals = np.zeros(len(KVals))
    boundVals = np.zeros(len(KVals))
    for ii, K in enumerate(KVals):
        betaVec = 1.0/(1.0 + gamma) * np.ones(K+1)
        for k in range(1, K):
            betaVec[k] = betaVec[k] * (1 - np.sum(betaVec[:k]))
        betaVec[-1] = 1 - np.sum(betaVec[:-1])
        print betaVec
        assert np.allclose(betaVec.sum(), 1.0)
        exactVals[ii] = cDir_exact(alpha, betaVec)
        boundVals[ii] = cDir_surrogate(alpha, betaVec)
    assert np.all(exactVals >= boundVals)
    pylab.plot(KVals, exactVals,
               'k-', linewidth=LINEWIDTH, label=labels[0])
    pylab.plot(KVals, boundVals,
               'r--', linewidth=LINEWIDTH, label=labels[1])
    index = -1

    pylab.text(KVals[index]+.25, boundVals[index],
        txtlabel, fontsize=LEGENDSIZE-8)
    pylab.xlim([0, np.max(KVals)+7.5])
    pylab.gca().set_xticks([0, 10, 20, 30, 40, 50])
    pylab.xlabel("K", fontsize=FONTSIZE)
    pylab.ylabel("cDir function", fontsize=FONTSIZE)
    pylab.tick_params(axis='both', which='major', labelsize=TICKSIZE)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号