cohort.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:cohorts 作者: hammerlab 项目源码 文件源码
def plot_correlation(self, on, x_col=None, plot_type="jointplot", stat_func=pearsonr, show_stat_func=True, plot_kwargs=None, **kwargs):
        """Plot the correlation between two variables.

        Parameters
        ----------
        on : list or dict of functions or strings
            See `cohort.load.as_dataframe`
        x_col : str, optional
            If `on` is a dict, this guarantees we have the expected ordering.
            When given, the other of the two loaded columns is used as y.
            When omitted, the first loaded column is x and the second is y.
        plot_type : str, optional
            Specify "jointplot", "regplot", "boxplot", or "barplot".
        stat_func : function, optional.
            Specify which function to use for the statistical test. It is
            called as ``stat_func(series_x, series_y)`` and its result is
            unpacked into ``(coeff, p_value)``.
        show_stat_func : bool, optional
            Whether or not to show the stat_func result in the plot itself.
            Only used when `plot_type` is "jointplot".
        plot_kwargs : dict, optional
            kwargs to pass through to plotting functions. Defaults to no
            extra kwargs.
        **kwargs
            Passed through to `self.as_dataframe`.

        Returns
        -------
        CorrelationResults
            Built from the statistic (`coeff`, `p_value`), the `stat_func`
            used, the two plotted series and the plot object.

        Raises
        ------
        ValueError
            If `plot_type` is not one of the supported values, or if `on`
            does not produce exactly two columns.
        """
        if plot_type not in ["boxplot", "barplot", "jointplot", "regplot"]:
            raise ValueError("Invalid plot_type %s" % plot_type)
        # A None default (instead of `{}`) avoids one dict object being
        # shared by every call that omits plot_kwargs.
        if plot_kwargs is None:
            plot_kwargs = {}
        plot_cols, df = self.as_dataframe(on, return_cols=True, **kwargs)
        if len(plot_cols) != 2:
            raise ValueError("Must be comparing two columns, but there are %d columns" % len(plot_cols))
        # Drop rows where either variable is null before testing/plotting.
        for plot_col in plot_cols:
            df = filter_not_null(df, plot_col)
        if x_col is None:
            x_col = plot_cols[0]
            y_col = plot_cols[1]
        elif x_col == plot_cols[0]:
            y_col = plot_cols[1]
        else:
            # NOTE(review): any x_col other than plot_cols[0] is treated as
            # plot_cols[1]; an unknown column name surfaces as a KeyError on
            # the df lookup below.
            y_col = plot_cols[0]
        series_x = df[x_col]
        series_y = df[y_col]
        coeff, p_value = stat_func(series_x, series_y)
        if plot_type == "jointplot":
            # NOTE(review): recent seaborn releases may no longer accept the
            # `stat_func` argument of jointplot — confirm the pinned version.
            plot = sb.jointplot(data=df, x=x_col, y=y_col,
                                stat_func=stat_func if show_stat_func else None,
                                **plot_kwargs)
        elif plot_type == "regplot":
            plot = sb.regplot(data=df, x=x_col, y=y_col,
                              **plot_kwargs)
        elif plot_type == "boxplot":
            plot = stripboxplot(data=df, x=x_col, y=y_col, **plot_kwargs)
        else:
            plot = sb.barplot(data=df, x=x_col, y=y_col, **plot_kwargs)
        return CorrelationResults(coeff=coeff, p_value=p_value, stat_func=stat_func,
                                  series_x=series_x, series_y=series_y, plot=plot)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号