def plot_deviations(self, feature_column):
    """Draw a heatmap of how each predicted class deviates from the
    expected distribution of ``feature_column``.

    Expected proportions are obtained from ``get_distribution_by_class``
    keyed on ``self.model['labelling'][0]``; observed proportions and the
    raw observed values are keyed on ``'y_pred'``.  Each cell is coloured
    by the relative deviation ``(observed - expected) / expected`` using
    the ``'RdBu_r'`` colour map with the colour scale fixed to [-1, 1].
    Each cell is annotated with the raw observed value, formatted with ``'g'``.

    :param feature_column: name of the column on which to plot distributions
    :type feature_column: str
    :returns: heatmap of deviations, as returned by ``sns.heatmap``
    :rtype: matplotlib Axes
    """
    reference_labelling = self.model['labelling'][0]

    # The third argument appears to select proportions (True) versus raw
    # values (False), judging by how the results are used below.
    expected = self.get_distribution_by_class(
        feature_column, reference_labelling, True)
    observed = self.get_distribution_by_class(
        feature_column, 'y_pred', True)
    observed_raw = self.get_distribution_by_class(
        feature_column, 'y_pred', False)

    # NOTE(review): a zero expected proportion makes this division produce
    # inf/NaN (assuming pandas/NumPy operands) -- confirm whether such
    # cells should be masked instead.
    difference = observed - expected
    relative_deviation = difference / expected

    heatmap_options = {
        'cmap': 'RdBu_r',
        'vmin': -1,
        'vmax': 1,
        'annot': observed_raw,
        'fmt': 'g',
    }
    heatmap_ax = sns.heatmap(relative_deviation, **heatmap_options)

    # Tick labels are applied in reverse order of self.labels; presumably
    # this matches the row order of get_distribution_by_class -- verify.
    heatmap_ax.set(
        xlabel=feature_column,
        ylabel='predicted class',
        yticklabels=reversed(self.labels),
    )
    return heatmap_ax
评论列表
文章目录