def plot_forest_all_proba(y_proba_all, y_gt):
    """Show a heatmap of each tree's score for the ground-truth class.

    Row ``i`` of the plotted matrix corresponds to ``y_proba_all[i]`` and
    column ``j`` to sample ``j``; cell ``(i, j)`` is
    ``y_proba_all[i][j, y_gt[j]]``. The colour scale is fixed to [0, 1].

    Args:
        y_proba_all: Sequence with one entry per tree. Each entry is indexed
            as a 2-D ``(n_samples, n_classes)`` array; presumably these are
            class probabilities (e.g. from ``predict_proba``) -- TODO confirm
            against callers. Nested lists are accepted and converted with
            ``np.asarray``.
        y_gt: Integer ground-truth class index for each sample. A column
            vector of shape ``(n_samples, 1)`` is flattened before use.

    Returns:
        None. Clears the current figure, draws the matrix into it, and calls
        ``plt.show()``.
    """
    # pyplot is the supported interface; matplotlib discourages ``pylab``.
    from matplotlib import pyplot as plt

    # Flatten so an (N, 1) column of labels indexes like a flat (N,) vector
    # instead of broadcasting to an (N, N) selection.
    labels = np.asarray(y_gt).reshape(-1)
    n_samples = labels.shape[0]
    rows = np.arange(n_samples)

    plt.clf()
    mat = np.zeros((len(y_proba_all), n_samples))
    LOGGER.info('mat.shape=%s', mat.shape)
    for i, tree_proba in enumerate(y_proba_all):
        # For every sample, pick the column of its ground-truth class.
        mat[i, :] = np.asarray(tree_proba)[rows, labels]

    # fignum=0 draws into the current axes of the (just cleared) figure
    # rather than opening a new figure window.
    plt.matshow(mat, fignum=0, cmap='Blues', vmin=0, vmax=1.0)
    plt.grid(False)
    plt.show()
评论列表
文章目录