dataset_merging.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:time_seires_prediction_using_lstm 作者: CasiaFan 项目源码 文件源码
def qq_plot(self, df_samp, df_clu):
        """Draw a percentile-based Q-Q plot of a sample against a cluster and save it as a PNG.

        Every sample value is mapped to its percentile rank within the cluster
        data (x axis, "theoretical" quantiles) and within the sample itself
        (y axis, sample quantiles). A least-squares line is fitted through those
        points and drawn together with the 45-degree reference line.

        The figure is written to
        ``<self.output_dir>/qq-plot/enterprise-<samp_id>-VS-cluster-<ref_id>.qqplot.png``,
        where the ids are the DataFrames' column names.

        :param df_samp: interval DataFrame of the sample enterprise; its column
            name (the enterprise id) is used in messages, labels and the file name.
        :param df_clu: interval DataFrame of the cluster, used as the reference
            distribution; its column name is used as the cluster id.
        :return: tuple ``(slope, intercept, pct_error)`` of the fitted line
            ``y = slope * x + intercept``; ``pct_error`` is the sum of squared
            differences between the y and x percentiles divided by ``2 * n``
            (i.e. half the mean squared error).
        :raises ValueError: if either DataFrame contains no values.
        """
        # Flatten to 1-D: np.asarray(DataFrame) is 2-D (n, 1), and percentileofscore
        # expects a 1-D reference array and scalar scores.
        ref = np.asarray(df_clu).ravel()
        samp = np.asarray(df_samp).ravel()
        if ref.size == 0 or samp.size == 0:
            raise ValueError("qq_plot requires non-empty sample and cluster data")
        # Column names are the ids; join them into plain strings so messages and the
        # output file name do not contain a pandas Index repr.
        ref_id = "_".join(str(c) for c in df_clu.columns)
        samp_id = "_".join(str(c) for c in df_samp.columns)
        outdir = os.path.join(self.output_dir, "qq-plot")
        # make output directory if it does not exist yet
        if not os.path.exists(outdir):
            os.makedirs(outdir)
        print("Start drawing Q-Q plot using data from sample {} and cluster {}.".format(samp_id, ref_id))
        # theoretical quantiles: percentile rank of each sample value within the cluster data
        samp_pct_x = np.asarray([percentileofscore(ref, x) for x in samp])
        # sample quantiles: percentile rank of each sample value within the sample itself
        samp_pct_y = np.asarray([percentileofscore(samp, x) for x in samp])
        # half of the mean squared difference between sample and theoretical percentiles
        pct_error = np.sum(np.power(samp_pct_y - samp_pct_x, 2)) / (2 * len(samp_pct_x))
        # Single least-squares fit y = p[0] * x + p[1]; it drives the printed equation,
        # the plotted line and the return value, so they always agree.
        p = np.polyfit(samp_pct_x, samp_pct_y, 1)
        fitted_y = np.polyval(p, samp_pct_x)
        ss_res = np.sum((samp_pct_y - fitted_y) ** 2)
        ss_tot = np.sum((samp_pct_y - samp_pct_y.mean()) ** 2)
        if ss_tot > 0:
            r2 = 1.0 - ss_res / ss_tot
        else:
            # constant y: perfect fit scores 1.0, otherwise 0.0 (sklearn r2_score convention)
            r2 = 1.0 if ss_res == 0 else 0.0
        if p[1] > 0:
            p_function = "y= {} x + {}, r-square = {}".format(p[0], p[1], r2)
        elif p[1] < 0:
            p_function = "y= {} x - {}, r-square = {}".format(p[0], -p[1], r2)
        else:
            p_function = "y= {} x, r-square = {}".format(p[0], r2)
        print("The fitted linear regression model in Q-Q plot using data from enterprises {} and cluster {} is {}".format(samp_id, ref_id, p_function))
        # 0, 20, ..., 100 -- include 100 so the upper axis limit is labelled too
        ticks = np.arange(0, 101, 20)
        outfile = "{}/enterprise-{}-VS-cluster-{}.qqplot.png".format(outdir, samp_id, ref_id)
        # Draw on a dedicated figure so a caller's current figure is not touched, and
        # make sure it is closed even if plotting or saving fails.
        fig = plt.figure()
        try:
            plt.scatter(x=samp_pct_x, y=samp_pct_y, color='blue')
            plt.xlim((0, 100))
            plt.ylim((0, 100))
            # add fitted regression line
            plt.plot(samp_pct_x, fitted_y, color='red', linewidth=2)
            # add 45-degree reference line
            plt.plot([0, 100], [0, 100], linewidth=2)
            plt.text(10, 70, p_function)
            plt.xticks(ticks, ticks)
            plt.yticks(ticks, ticks)
            plt.xlabel('cluster quantiles - id: {}'.format(ref_id))
            plt.ylabel('sample quantiles - id: {}'.format(samp_id))
            plt.title('{} VS {} Q-Q plot'.format(ref_id, samp_id))
            plt.savefig(outfile)
        finally:
            plt.close(fig)
        print("Plotting Q-Q plot done! The plot is stored at {}.".format(outfile))
        return p[0], p[1], pct_error
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号