def plot_confusion_matrix(labels, confusion_matrix, directory, name, extension):
    """
    Plot a confusion matrix on a logarithmic color scale and save it to disk.

    The matrix is plotted exactly as given (it is NOT normalized here). Each
    y-axis tick label is annotated with the per-class IoU from
    ``calculate_iou``. The title shows the mean IoU over the classes whose IoU
    is not NaN. The figure is written to ``<directory>/<name>.<extension>`` and
    then closed.

    Args:
        labels: Sequence of label records; each must support ``label['name']``.
        confusion_matrix: Matrix passed to ``calculate_iou`` and ``imshow``.
        directory: Directory the image is written to.
        name: Output file name without extension.
        extension: File extension, e.g. ``'png'``.
    """
    # Assumes calculate_iou returns one IoU per class, in the same order as
    # `labels` (it is zipped with `labels` below) -- TODO confirm.
    ious = calculate_iou(confusion_matrix)
    # Grow the figure with the number of classes so tick labels stay legible.
    size = len(labels) / 5 + 2
    fig, ax = plt.subplots(figsize=(size + 2, size))
    try:
        # LogNorm maps counts to colors on a log scale; non-positive cells are
        # masked by LogNorm and therefore drawn as the axes background.
        plot = ax.imshow(confusion_matrix, interpolation='nearest',
                         cmap=plt.cm.Blues, norm=LogNorm())

        ticks_with_iou = []
        ticks_without_iou = []
        ious_for_average = []
        for label, iou in zip(labels, ious):
            if math.isnan(iou):
                # Show an undefined IoU as 0% but keep it out of the mean.
                iou = 0
            else:
                ious_for_average.append(iou)
            ticks_with_iou.append("{}: {:.2%}".format(label['name'], iou))
            ticks_without_iou.append("{}".format(label['name']))

        # np.average([]) warns ("Mean of empty slice") and yields NaN; produce
        # the NaN explicitly when no class has a defined IoU.
        if ious_for_average:
            avg_iou = np.average(ious_for_average)
        else:
            avg_iou = float('nan')

        fig.colorbar(plot)
        tick_marks = np.arange(len(ious))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(ticks_without_iou, rotation=90)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(ticks_with_iou)
        ax.set_title("Average IoU: {:.2%}".format(avg_iou))
        ax.set_xlabel('Predicted label')
        ax.set_ylabel('True label')
        fig.tight_layout()
        fig.savefig(os.path.join(directory, '{}.{}'.format(name, extension)))
    finally:
        # Release the figure even if saving fails; pyplot otherwise keeps
        # every created figure alive for the life of the process.
        plt.close(fig)
评论列表
文章目录