model_evaluator.py source file

python

Project: johnson-county-ddj-public  Author: dssg

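import seaborn as sns


# Method of the class defined in model_evaluator.py (class body not shown).
def plot_deviations(self, feature_column):
    """Plot deviations from the expected distribution of a feature within
    each predicted class.

    The expected distribution comes from the label column
    ``self.model['labelling'][0]`` and the observed one from ``'y_pred'``.
    Each cell holds the relative deviation (observed - expected) / expected
    and is annotated with the observed value.

    :param feature_column: name of the column on which to plot distributions
    :type feature_column: str
    :returns: heatmap of deviations
    :rtype: matplotlib.axes.Axes
    """
    # The boolean flag selects within-class proportions (True) or raw
    # values (False), as the variable names indicate.
    expected_proportions = self.get_distribution_by_class(
        feature_column, self.model['labelling'][0], True)

    observed_proportions = self.get_distribution_by_class(
        feature_column, 'y_pred', True)

    observed_values = self.get_distribution_by_class(
        feature_column, 'y_pred', False)

    # Relative deviation; cells whose expected proportion is 0 come out
    # as inf or NaN.
    proportion_deviation = ((observed_proportions - expected_proportions)
                            / expected_proportions)

    # Diverging colormap centred on 0; deviations beyond +/-100% saturate.
    deviation_plot = sns.heatmap(proportion_deviation, cmap='RdBu_r',
                                 vmin=-1, vmax=1,
                                 annot=observed_values, fmt='g')

    deviation_plot.set(xlabel=feature_column, ylabel='predicted class',
                       yticklabels=list(reversed(self.labels)))

    return deviation_plot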
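
`get_distribution_by_class` is not part of this excerpt, so the sketch below only illustrates the relative-deviation idea with the standard library. It assumes that the helper returns, for each class, either raw counts or within-class proportions of each feature value; `distribution_by_class`, `relative_deviation`, and the toy column names `x`, `y_true`, `y_pred` are hypothetical stand-ins, and plain dicts replace the DataFrames the original presumably uses.

```python
from collections import Counter, defaultdict


def distribution_by_class(rows, feature, class_col, normalize=True):
    """Hypothetical stand-in for get_distribution_by_class.

    Returns {class: {feature value: count}} or, when normalize is True,
    {class: {feature value: proportion within that class}}.
    """
    counts = defaultdict(Counter)
    for row in rows:
        counts[row[class_col]][row[feature]] += 1
    if not normalize:
        return {cls: dict(c) for cls, c in counts.items()}
    dists = {}
    for cls, c in counts.items():
        total = sum(c.values())
        dists[cls] = {value: n / total for value, n in c.items()}
    return dists


def relative_deviation(observed, expected):
    """(observed - expected) / expected for every (class, value) cell of
    the expected distribution; None where the expected proportion is 0."""
    result = {}
    for cls, exp_dist in expected.items():
        obs_dist = observed.get(cls, {})
        result[cls] = {
            value: (obs_dist.get(value, 0.0) - p) / p if p else None
            for value, p in exp_dist.items()
        }
    return result


# Toy data: feature "x", true label "y_true", prediction "y_pred".
rows = [
    {"x": "a", "y_true": 1, "y_pred": 1},
    {"x": "a", "y_true": 1, "y_pred": 0},
    {"x": "b", "y_true": 1, "y_pred": 1},
    {"x": "b", "y_true": 0, "y_pred": 1},
]
expected = distribution_by_class(rows, "x", "y_true")
observed = distribution_by_class(rows, "x", "y_pred")
print(relative_deviation(observed, expected))
# {1: {'a': -0.5, 'b': 1.0}, 0: {'b': -1.0}}
```

A deviation of -1 means the value never occurs in that predicted class. Because `plot_deviations` passes `vmin=-1, vmax=1` to the heatmap, any deviation above +1 is drawn in the same colour as +1.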