def plot_mondrian_kernel_vs_mondrian_forest(lifetime_max, res):
    """Plot train/test error of the Mondrian kernel and the Mondrian forest.

    Both methods are based on the same set of M Mondrian samples. The four
    error curves are drawn as step functions against the lifetime on a
    logarithmic x-axis.

    Args:
        lifetime_max: Upper limit of the x-axis (the lower limit is 1e-8).
        res: Dictionary returned by the evaluate_all_lifetimes procedure in
            mondrian_kernel.py. The keys read here are 'times',
            'forest_train', 'forest_test', 'kernel_train' and 'kernel_test'.

    Returns:
        None. The curves are drawn on a newly created 7x4 figure.
    """
    times = res['times']
    forest_train = res['forest_train']
    forest_test = res['forest_test']
    kernel_train = res['kernel_train']
    kernel_test = res['kernel_test']

    # set up test error plot
    fig = plt.figure(figsize=(7, 4))
    # Integer subplot spec: the string form '111' is rejected by newer
    # Matplotlib releases.
    ax = fig.add_subplot(111)
    remove_chartjunk(ax)

    # Raw strings keep the backslashes without relying on Python passing
    # through unrecognised escape sequences (a SyntaxWarning on 3.12+).
    ax.set_xlabel(r'lifetime $\lambda$')
    ax.set_ylabel(r'relative error [\%]')
    # The visibility flag is passed positionally because its keyword name
    # differs between Matplotlib versions (`b` was renamed to `visible`).
    ax.yaxis.grid(True, which='major', linestyle='dotted', lw=0.5,
                  color='black', alpha=0.3)
    ax.set_xscale('log')
    ax.set_xlim((1e-8, lifetime_max))
    ax.set_ylim((0, 25))

    # Single switch for rasterizing all four curves.
    rasterized = False
    # Test curves get an explicit lw=2; train curves use the default width.
    ax.plot(times, forest_test, drawstyle="steps-post", ls='-', lw=2,
            color=tableau20(6), label='"M. forest" (test)',
            rasterized=rasterized)
    ax.plot(times, forest_train, drawstyle="steps-post", ls='-',
            color=tableau20(7), label='"M. forest" (train)',
            rasterized=rasterized)
    ax.plot(times, kernel_test, drawstyle="steps-post", ls='-', lw=2,
            color=tableau20(4), label='M. kernel (test)',
            rasterized=rasterized)
    ax.plot(times, kernel_train, drawstyle="steps-post", ls='-',
            color=tableau20(5), label='M. kernel (train)',
            rasterized=rasterized)

    # Anchor the frameless legend slightly outside the top-right corner.
    ax.legend(bbox_to_anchor=[1.15, 1.05], frameon=False)
experiment_3_mondrian_kernel_vs_forest.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录