def plotBoundVsK(KVals=np.arange(1, 50),
                 alpha=0.5,
                 gamma=10,
                 labels=None,
                 betaFunc='prior'):
    """Plot cDir_exact and its cDir_surrogate bound as a function of K.

    For each K in KVals, a (K+1)-length vector betaVec is built by
    stick-breaking with a constant fraction v = 1/(1+gamma): entries
    0..K-1 are v, v(1-v), v(1-v)^2, ... and the final entry holds the
    leftover mass, so betaVec sums to 1. Both cDir_exact(alpha, betaVec)
    and cDir_surrogate(alpha, betaVec) are evaluated, and the two curves
    are drawn on the current pylab axes.

    Args:
        KVals: 1-D sequence of non-negative integer truncation levels.
        alpha: scalar passed to cDir_exact / cDir_surrogate; also shown as
            a text annotation at the last point of the bound curve.
        gamma: each stick takes fraction 1/(1+gamma) of the remaining mass.
        labels: optional pair [exactLabel, boundLabel] used as legend
            labels. When None, the curves get no legend label and the
            annotation is str(alpha); otherwise the annotation is
            'alpha\\n' followed by str(alpha).
        betaFunc: currently unused; kept for interface compatibility.

    Raises:
        AssertionError: if a betaVec does not sum to 1, or if any exact
            value is smaller than the corresponding surrogate value (the
            surrogate is checked to be a lower bound).
    """
    if labels is None:
        txtlabel = str(alpha)
        labels = [None, None]
    else:
        txtlabel = 'alpha\n' + str(alpha)

    stickFrac = 1.0 / (1.0 + gamma)
    exactVals = np.zeros(len(KVals))
    boundVals = np.zeros(len(KVals))
    for ii, K in enumerate(KVals):
        # Constant-fraction stick-breaking: entry k takes stickFrac of the
        # mass left after entries 0..k-1; the last entry gets the rest.
        betaVec = np.zeros(K + 1)
        remaining = 1.0
        for k in range(K):
            betaVec[k] = stickFrac * remaining
            remaining -= betaVec[k]
        betaVec[-1] = 1.0 - np.sum(betaVec[:-1])
        assert np.allclose(betaVec.sum(), 1.0)
        exactVals[ii] = cDir_exact(alpha, betaVec)
        boundVals[ii] = cDir_surrogate(alpha, betaVec)
    # Sanity check: the surrogate must never exceed the exact value.
    assert np.all(exactVals >= boundVals)

    pylab.plot(KVals, exactVals,
               'k-', linewidth=LINEWIDTH, label=labels[0])
    pylab.plot(KVals, boundVals,
               'r--', linewidth=LINEWIDTH, label=labels[1])
    # Annotate the last plotted point of the bound curve.
    pylab.text(KVals[-1] + .25, boundVals[-1],
               txtlabel, fontsize=LEGENDSIZE - 8)
    # Extra margin to the right of the largest K, where the text is drawn.
    xMax = np.max(KVals) + 7.5
    pylab.xlim([0, xMax])
    # Ticks every 10 within the visible range. Deriving them from xMax keeps
    # set_xticks from widening the view past xlim for small KVals; for the
    # default KVals this is [0, 10, 20, 30, 40, 50] as before.
    pylab.gca().set_xticks(np.arange(0, int(xMax) + 1, 10))
    pylab.xlabel("K", fontsize=FONTSIZE)
    pylab.ylabel("cDir function", fontsize=FONTSIZE)
    pylab.tick_params(axis='both', which='major', labelsize=TICKSIZE)
评论列表
文章目录