TestELBOPenalizesEmptyComps.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:bnpy 作者: bnpy 项目源码 文件源码
def makeFigure(**kwargs):
    """Plot the change in ELBO as empty states are added to a solution.

    For each count of empty states in (0, 1, 2, 3), a responsibility
    matrix is built with makeNewRespWithEmptyStates and scored twice with
    resp2ELBO_HDPTopicModel: once with doPointEstimate=1 and once with
    the default settings. Both curves are shifted so the value at zero
    empty states is 0, then divided by the largest absolute value of the
    default-settings curve.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to makeDataAndTrueResp and to every
        resp2ELBO_HDPTopicModel call.

    Returns
    -------
    None
        The figure is drawn on a newly created pylab figure, which is
        left as the current figure.
    """
    Data, trueResp = makeDataAndTrueResp(**kwargs)

    kemptyVals = np.asarray([0, 1, 2, 3.])
    # Builtin float replaces the np.float alias, which NumPy 1.24 removed.
    ELBOVals = np.zeros_like(kemptyVals, dtype=float)
    PointEstELBOVals = np.zeros_like(kemptyVals, dtype=float)

    # Iterate over the number of empty states (0, 1, 2, ...)
    for ii, kempty in enumerate(kemptyVals):
        resp = makeNewRespWithEmptyStates(trueResp, kempty)
        PointEstELBOVals[ii] = resp2ELBO_HDPTopicModel(
            Data,
            resp,
            doPointEstimate=1,
            **kwargs)
        ELBOVals[ii] = resp2ELBO_HDPTopicModel(Data, resp, **kwargs)

    # Shift both curves so the kempty=0 entry is exactly zero; the plot
    # then shows the change relative to the solution with no empty states.
    PointEstELBOVals -= PointEstELBOVals[0]
    ELBOVals -= ELBOVals[0]

    # Rescale so that yaxis has units on order of 1, not 0.001.
    # Skip the division when the curve is flat: 0/0 would turn every
    # plotted value into NaN.
    scale = np.max(np.abs(ELBOVals))
    if scale > 0:
        ELBOVals /= scale
        PointEstELBOVals /= scale

    # Set buffer-space for defining plotable area
    xB = 0.25  # left margin of the axes; also reused as x-limit padding
    B = 0.19  # big buffer for sides where we will put text labels
    b = 0.01  # small buffer for other sides
    TICKSIZE = 30
    FONTSIZE = 40
    LEGENDSIZE = 30
    LINEWIDTH = 4

    # Plot the results
    pylab.figure(figsize=(9.1, 6))
    axH = pylab.subplot(111)
    axH.set_position([xB, B, (1 - xB - b), (1 - B - b)])

    plotargs = dict(markersize=20, linewidth=LINEWIDTH)
    pylab.plot(kemptyVals, PointEstELBOVals, 'v-', label='HDP point est',
               color='b', markeredgecolor='b',
               **plotargs)
    # The 'HDP exact' curve is drawn as a constant zero line; it is not
    # computed from Data here.
    pylab.plot(kemptyVals, np.zeros_like(kemptyVals), 's:', label='HDP exact',
               color='g', markeredgecolor='g',
               **plotargs)
    pylab.plot(kemptyVals, ELBOVals, 'o--', label='HDP surrogate',
               color='r', markeredgecolor='r',
               **plotargs)

    pylab.xlabel('num. empty topics', fontsize=FONTSIZE)
    pylab.ylabel('change in ELBO', fontsize=FONTSIZE)
    # Pad the x-range so the end markers are not clipped by the frame.
    pylab.xlim([-xB, kemptyVals[-1] + xB])
    pylab.xticks(kemptyVals)
    pylab.yticks([-1, 0, 1])

    axH.tick_params(axis='both', which='major', labelsize=TICKSIZE)
    pylab.legend(loc='upper left', prop={'size': LEGENDSIZE})
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号