def qq_plot(self, df_samp, df_clu):
    """Draw a percentile-based Q-Q plot of a sample against a cluster and fit a line.

    Each sample value is placed at (percentile of the value within the cluster
    data, percentile of the value within the sample itself), so samples drawn
    from the cluster's distribution fall near the 45-degree reference line.
    A least-squares line is fitted through the points, and the figure is saved
    as ``<self.output_dir>/qq-plot/enterprise-<samp_id>-VS-cluster-<ref_id>.qqplot.png``.

    :param df_samp: interval DataFrame of the sample enterprise. Its column
        label is used as the sample id. A single column is expected; if there
        are several, all values are pooled.
    :param df_clu: interval DataFrame of the cluster, used as the reference
        distribution. Its column label is used as the cluster id.
    :return: tuple ``(slope, intercept, pct_error)`` of the fitted line, where
        ``pct_error`` is the sum of squared percentile differences divided by
        ``2 * n`` (i.e. half the mean squared error).
    :raises ValueError: if either DataFrame holds no values.
    """
    def _frame_id(df):
        # Use the column label(s), not the pandas Index repr, so ids are
        # readable in messages and safe to embed in file names.
        return "-".join("{}".format(col) for col in df.columns)

    outdir = os.path.join(self.output_dir, "qq-plot")
    # Create the output directory, tolerating that it already exists
    # (including when it is created concurrently between check and create).
    try:
        os.makedirs(outdir)
    except OSError:
        if not os.path.isdir(outdir):
            raise

    # Flatten to 1-D: np.asarray on a one-column DataFrame has shape (n, 1).
    # Recent scipy versions of percentileofscore compare along the last axis,
    # which would yield per-row results instead of one percentile per value,
    # and np.polyfit only accepts 1-D x.
    ref = np.asarray(df_clu).ravel()
    samp = np.asarray(df_samp).ravel()
    ref_id = _frame_id(df_clu)
    samp_id = _frame_id(df_samp)
    if ref.size == 0 or samp.size == 0:
        raise ValueError(
            "Q-Q plot needs non-empty data (sample {}: {} values, cluster {}: {} values)".format(
                samp_id, samp.size, ref_id, ref.size))
    print("Start drawing Q-Q plot using data from sample {} and cluster {}.".format(samp_id, ref_id))

    # x: percentile of each sample value within the reference (cluster) data
    samp_pct_x = np.asarray([percentileofscore(ref, x) for x in samp])
    # y: percentile of each sample value within the sample itself
    samp_pct_y = np.asarray([percentileofscore(samp, x) for x in samp])
    # Half the mean squared distance between the two percentile sets
    # (sum / (2 * n)); the formula is kept because callers receive this value.
    pct_error = np.sum(np.power(samp_pct_y - samp_pct_x, 2)) / (2 * len(samp_pct_x))

    # Least-squares line y = slope * x + intercept. R^2 is derived from the
    # same fit, so the reported equation and the drawn line always agree.
    p = np.polyfit(samp_pct_x, samp_pct_y, 1)
    slope, intercept = p[0], p[1]
    fitted = np.polyval(p, samp_pct_x)
    ss_res = np.sum((samp_pct_y - fitted) ** 2)
    ss_tot = np.sum((samp_pct_y - np.mean(samp_pct_y)) ** 2)
    if ss_tot > 0:
        r2 = float(1.0 - ss_res / ss_tot)
    else:
        # Constant y: follow sklearn's r2_score convention (1.0 only for a perfect fit).
        r2 = 1.0 if ss_res == 0 else 0.0

    if intercept > 0:
        p_function = "y= {} x + {}, r-square = {}".format(slope, intercept, r2)
    elif intercept < 0:
        p_function = "y= {} x - {}, r-square = {}".format(slope, -intercept, r2)
    else:
        p_function = "y= {} x, r-square = {}".format(slope, r2)
    print("The fitted linear regression model in Q-Q plot using data from enterprises {} and cluster {} is {}".format(samp_id, ref_id, p_function))

    ticks = np.arange(0, 101, 20)  # 0..100 inclusive, matching the axis limits
    outfile = os.path.join(outdir, "enterprise-{}-VS-cluster-{}.qqplot.png".format(samp_id, ref_id))
    # Use an explicit figure so no unrelated open figure is drawn on.
    fig, ax = plt.subplots()
    try:
        ax.scatter(samp_pct_x, samp_pct_y, color='blue')
        ax.set_xlim((0, 100))
        ax.set_ylim((0, 100))
        # fitted regression line, drawn between the extreme x values
        line_x = np.array([samp_pct_x.min(), samp_pct_x.max()])
        ax.plot(line_x, np.polyval(p, line_x), color='red', linewidth=2)
        # 45-degree reference line
        ax.plot([0, 100], [0, 100], linewidth=2)
        ax.text(10, 70, p_function)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xlabel('cluster quantiles - id: {}'.format(ref_id))
        ax.set_ylabel('sample quantiles - id: {}'.format(samp_id))
        ax.set_title('{} VS {} Q-Q plot'.format(ref_id, samp_id))
        fig.savefig(outfile)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print("Plotting Q-Q plot done! The plot is stored at {}.".format(outfile))
    return slope, intercept, pct_error
# Source file: dataset_merging.py
# Language: python
# Views: 22
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents