def scatterColor(x0, y, w, cbar_label='Julian Date'):
    """Draw a scatter plot colored by a third variable, with two fit lines.

    The points ``(x0, y)`` are drawn with pyplot and colored by ``w`` using
    the ``RdYlBu`` colormap. Two straight lines are overlaid between the
    smallest and largest x value: an ordinary least-squares fit from
    ``scipy.stats.linregress`` (green) and a robust linear model (RLM) fit
    from statsmodels (red). The legend reports slope, intercept and
    r-squared for each line; the r-squared shown for the RLM line comes
    from a weighted least-squares fit that reuses the final RLM weights.
    A legend and a colorbar are added to the plot.

    This function uses the names ``np`` and ``plt`` from the enclosing
    module; statsmodels and scipy are imported locally.

    :param x0: x values to plot
    :type x0: list
    :param y: y values to plot
    :type y: list
    :param w: values used to color the points
    :type w: list
    :param cbar_label: text placed next to the colorbar
    :type cbar_label: str
    :returns: slope and intercept of the RLM best fit line shown on the plot
    :rtype: tuple
    .. warning:: all input arrays must have matching lengths and scalar values
    .. note:: See documentation at http://statsmodels.sourceforge.net/0.6.0/generated/statsmodels.robust.robust_linear_model.RLM.html
        for the RLM line
    """
    import statsmodels.api as sm
    from scipy.stats import linregress

    # Accept plain lists as documented; the original called w.min() directly,
    # which only works on array-like objects that have that method.
    x0 = np.asarray(x0, dtype=float)
    y = np.asarray(y, dtype=float)
    w = np.asarray(w)

    # Let scatter do the color mapping itself so the returned collection can
    # be handed to colorbar(); a free-standing ScalarMappable has no Axes and
    # newer matplotlib refuses to build a colorbar from it without `ax=`.
    points = plt.scatter(x0, y, c=w, cmap='RdYlBu', label='')

    # Robust fit: column 0 of the design matrix is the constant term.
    design = sm.add_constant(x0)
    est = sm.RLM(y, design).fit()
    r2 = sm.WLS(y, design, weights=est.weights).fit().rsquared
    const = est.params[0]
    slope = est.params[1]

    # Ordinary least-squares fit for comparison.
    lin = linregress(x0, y)

    # A straight line only needs its two end points; this also guarantees
    # the line spans the full x range regardless of its width.
    x_line = np.array([x0.min(), x0.max()])
    plt.plot(x_line, lin[0] * x_line + lin[1], c='g',
             label='simple linear regression m = {:.2f} b = {:.0f}, r^2 = {:.2f}'.format(lin[0], lin[1], lin[2] ** 2))
    plt.plot(x_line, slope * x_line + const, c='r',
             label='rlm regression m = {:.2f} b = {:.0f}, r2 = {:.2f}'.format(slope, const, r2))
    plt.legend()

    cbar = plt.colorbar(points)
    cbar.set_label(cbar_label)
    return slope, const
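# Comment list  (NOTE(review): appears to be page-navigation text copied along with the snippet, not code)
# Table of contents  (NOTE(review): same as above)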