def plot_2d(dataset, nbins, data=None, extra=None):
    """Render a 2-D count image of paired integer values and save it as a PDF.

    Each row of ``data`` is treated as an (i, j) cell index. Pixel (i, j) of
    the image holds the number of rows whose first two values equal exactly
    (i, j). Rows whose values are non-integer, NaN, or outside the ``nbins``
    grid are ignored.

    Args:
        dataset: Dataset name. Used to locate the default CSV and to build
            the output file name.
        nbins: ``(rows, cols)`` shape of the count image.
        data: Optional array-like whose first two columns hold the (i, j)
            indices. When None, it is loaded from
            ``experiments/uci/data/splits/<dataset>_all.csv`` (first line
            skipped, comma-delimited), keeping only the last two columns.
        extra: Optional suffix appended to ``dataset`` in the output name.

    Returns:
        None. Writes ``plots/marginals-<name>.pdf``, where ``<name>`` is
        ``dataset`` (plus ``extra`` if given) with underscores replaced by
        hyphens. It then clears and closes the current figure.

    Note:
        The ``plt.rc`` calls modify matplotlib's global rcParams. This
        function does not explicitly restore them afterwards.
    """
    if data is None:
        data = np.loadtxt('experiments/uci/data/splits/{0}_all.csv'.format(dataset),
                          skiprows=1, delimiter=',')[:, -2:]
    # Accept plain sequences as well as ndarrays (column slicing needs an array).
    data = np.asarray(data)
    with sns.axes_style('white'):
        plt.rc('font', weight='bold')
        plt.rc('grid', lw=2)
        plt.rc('lines', lw=2)
        rows, cols = nbins
        im = np.zeros(nbins)
        # Vectorized replacement for a rows x cols loop of exact-equality
        # scans (O(rows*cols*N)). That loop also relied on the Python-2-only
        # ``xrange``.
        r = data[:, 0]
        c = data[:, 1]
        # Keep exactly the rows the equality scan would have matched: values
        # that are integer-valued and inside the grid. NaN fails every
        # comparison, so it is dropped too.
        keep = ((r >= 0) & (r < rows) & (c >= 0) & (c < cols)
                & (r == np.floor(r)) & (c == np.floor(c)))
        # np.add.at accumulates repeated (i, j) pairs; fancy-index += would
        # count each duplicate pair only once.
        np.add.at(im, (r[keep].astype(int), c[keep].astype(int)), 1)
        plt.imshow(im, cmap='gray_r', interpolation='none')
    if extra is not None:
        dataset += extra
    plt.savefig('plots/marginals-{0}.pdf'.format(dataset.replace('_', '-')),
                bbox_inches='tight')
    plt.clf()
    plt.close()
评论列表
文章目录