def scatterplot_matrix(data, attNames, *, figsize=(30, 30), **kwargs):
    """Draw an ``atts x atts`` grid of pairwise plots, one row/column per attribute.

    The subplot in row ``x``, column ``y`` plots attribute ``y`` (horizontal)
    against attribute ``x`` (vertical). Diagonal subplots carry the attribute
    name instead of a plot. Axes are hidden everywhere except on the outer
    edge of the grid. Visible axes alternate between opposite edges so that
    neighbouring tick labels do not collide:

    - x-axes: bottom row for even columns, top row for odd columns.
    - y-axes: last column for even rows, first column for odd rows.

    Parameters
    ----------
    data : array-like of shape (rows, atts)
        Observations in rows and attributes in columns. Columns are selected
        by position, so a NumPy array or a pandas DataFrame both work.
    attNames : sequence of str
        Labels written into the diagonal subplots, one per attribute.
    figsize : tuple of float, optional
        Figure size in inches, passed to ``plt.subplots``. The default
        ``(30, 30)`` is the previously hard-coded value.
    **kwargs
        Forwarded unchanged to every ``Axes.plot`` call. Pass, for example,
        ``linestyle='none', marker='.'`` for a true scatter.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the grid.
    """
    values = np.asarray(data)
    _, atts = values.shape
    # squeeze=False keeps ``axes`` a 2-D array even when atts == 1.
    fig, axes = plt.subplots(nrows=atts, ncols=atts, figsize=figsize,
                             squeeze=False)
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Hide every axis by default; selected outer axes are re-enabled below.
    # Grid position comes from the array index instead of the
    # Axes.is_first_col()/is_last_row()/... helpers, which were removed in
    # Matplotlib 3.6.
    last = atts - 1
    for (r, c), ax in np.ndenumerate(axes):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        if c == 0:
            ax.yaxis.set_ticks_position('left')
        if c == last:
            ax.yaxis.set_ticks_position('right')
        if r == 0:
            ax.xaxis.set_ticks_position('top')
        if r == last:
            ax.xaxis.set_ticks_position('bottom')

    # Fill both mirror-image off-diagonal cells for each pair (i < j).
    # Columns are indexed explicitly: values[y] would select a row.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i, j), (j, i)]:
            axes[x, y].plot(values[:, y], values[:, x], **kwargs)

    # Label the diagonal subplots with the attribute names.
    for i, label in enumerate(attNames):
        axes[i, i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                            ha='center', va='center')

    # Re-enable axes on alternating outer edges: j == -1 selects the
    # bottom row / last column, and j == 0 selects the top row / first column.
    for i, j in zip(range(atts), itertools.cycle((-1, 0))):
        axes[j, i].xaxis.set_visible(True)
        axes[i, j].yaxis.set_visible(True)

    return fig
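# Web-page residue from the original post ("Comment list" and
# "Table of contents" navigation headings); not part of the code.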