def detail_plot(df, tlow, thigh):
    """Draw a one-row diagram of a planetary system and return it as HTML.

    The host star is drawn at x=0, each planet at its semi-major axis, an
    optional habitable-zone band between ``hz1`` and ``hz2``, and the solar
    system planets from ``ss_planets`` as a reference row.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per planet of a single system. Columns read: ``hz1``, ``hz2``,
        ``teff``, ``radius`` (taken from the first row), ``plMass``,
        ``plRadius``, ``sma`` and ``plName``.
    tlow, thigh : number
        Limits of the stellar colour scale; ``tlow`` is raised to at least
        2500 and ``thigh`` lowered to at most 8500 before use.

    Returns
    -------
    Whatever ``fig_to_html`` returns for the figure.
    """
    import numbers

    def _has_sma(value):
        # A usable semi-major axis is a real number that is not NaN; anything
        # else (NaN, None, strings) is reported as "has no SMA".
        return isinstance(value, numbers.Real) and not np.isnan(value)

    hz1 = get_default(df['hz1'].values[0], -2, float)
    hz2 = get_default(df['hz2'].values[0], -1, float)
    color = get_default(df['teff'].values[0], 5777, float)
    tlow = get_default(max(2500, tlow), 2500, int)
    thigh = get_default(min(8500, thigh), 8500, int)
    R = df.iloc[0]['radius']
    planet_radii = [planetary_radius(mi, ri)
                    for mi, ri in df.loc[:, ['plMass', 'plRadius']].values]
    smas = df['sma'].values
    valid_smas = [v for v in smas if _has_sma(v)]
    if valid_smas:
        max_smas = max(valid_smas)
    else:
        # No planet has a usable SMA: pick a non-degenerate x-range instead of
        # letting max() raise on an empty sequence.
        max_smas = hz2 if hz2 > 0 else 1.0
    Rs = max(500, 500*R)
    rs = [max(80, 30*ri) for ri in planet_radii]

    fig, ax = plt.subplots(1, figsize=(14, 2))
    ax.scatter([0], [1], s=Rs, c=color, vmin=tlow, vmax=thigh, cmap=cm.autumn)
    no_sma = []
    if 0 < hz1 < hz2:
        # Habitable-zone band, shaded along x between hz1 and hz2.
        x = np.linspace(hz1, hz2, 10)
        y = np.linspace(0.9, 1.1, 10)
        z = np.array([[xi]*10 for xi in x[::-1]]).T
        plt.contourf(x, y, z, 300, alpha=0.8, cmap=cm.summer)
    for i, sma in enumerate(smas):
        if not _has_sma(sma):
            no_sma.append('{} has no SMA'.format(df['plName'].values[i]))
            continue
        if sma < hz1:
            # Colour by distance inside the HZ inner edge; dist is in [0, hz1].
            dist = hz1 - sma
            ax.scatter(sma, [1], s=rs[i], c=dist, vmin=0, vmax=hz1, cmap=cm.autumn)
        elif hz1 <= sma <= hz2:
            ax.scatter(sma, [1], s=rs[i], c='k', alpha=0.8)
        else:
            # Colour by distance beyond the HZ outer edge. dist lies in
            # (0, max_smas - hz2], so normalise over that same range (the
            # previous vmin=hz2/vmax=max_smas saturated small distances).
            dist = sma - hz2
            ax.scatter(sma, [1], s=rs[i], c=dist, vmin=0, vmax=max_smas - hz2,
                       cmap=cm.winter_r)
    jupiter_size = float(ss_planets['Jupiter'][1])
    for planet in ss_planets.keys():
        # ss_planets[planet] is read as (x position, size); sizes are
        # expressed relative to Jupiter's entry.
        s = ss_planets[planet][0]
        ss_size = 30*ss_planets[planet][1]/2.
        ss_size /= jupiter_size
        ax.scatter(s, [0.95], s=ss_size*10, c='g')
        ax.text(s-0.01, 0.97, planet, color='white')

    ax.set_xlim(0.0, max_smas*1.2)
    ax.set_ylim(0.9, 1.1)
    ax.set_xlabel('Semi-major axis [AU]')
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_yticks([])
    ax.spines['left'].set_visible(False)
    ax.set_facecolor('black')
    plt.tight_layout()

    for i, text in enumerate(no_sma):
        ax.text(max_smas*0.8, 1.05-i*0.02, text, color='white')

    return fig_to_html(fig)
python类cm()的实例源码
def _init_plot(self, dir, var, **kwargs):
    """Shared setup used by all windrose plotting methods.

    Pops the windrose-specific options from ``kwargs`` (``zorder`` is
    discarded, then ``bins``, ``nsector``, ``colors``, ``cmap``, ``normed``,
    ``blowto``), computes the direction/speed histogram into ``self._info``
    and returns what the callers need to draw it.

    Returns
    -------
    tuple
        ``(bins, nbins, nsector, colors, angles, kwargs)`` where ``kwargs``
        holds the remaining options to forward to matplotlib.

    Raises
    ------
    ValueError
        If ``colors`` is a list/tuple whose length differs from the number
        of bins.
    """
    import numbers

    # zorder is set per bin by the callers, so any user-supplied value is dropped.
    kwargs.pop('zorder', None)

    # Init of the bins array if not set
    bins = kwargs.pop('bins', None)
    if bins is None:
        bins = np.linspace(np.min(var), np.max(var), 6)
    # numbers.Integral also accepts NumPy integer scalars (np.int64 is not a
    # subclass of int on Python 3, and would otherwise crash in len() below).
    if isinstance(bins, numbers.Integral):
        bins = np.linspace(np.min(var), np.max(var), bins)
    bins = np.asarray(bins)
    nbins = len(bins)

    # Number of sectors
    nsector = kwargs.pop('nsector', None)
    if nsector is None:
        nsector = 16

    # Sets the colors table based on the colormap or the "colors" argument
    colors = kwargs.pop('colors', None)
    cmap = kwargs.pop('cmap', None)
    if colors is not None:
        if isinstance(colors, str):
            colors = [colors]*nbins
        if isinstance(colors, (tuple, list)):
            if len(colors) != nbins:
                raise ValueError("colors and bins must have same length")
    else:
        if cmap is None:
            cmap = cm.jet
        colors = self._colors(cmap, nbins)

    # Sector angles, starting at pi/2 and going clockwise. Built from an
    # integer range: np.arange with a float step can yield nsector+1 values
    # due to rounding, while this form gives exactly nsector bit-identical ones.
    angles = np.pi/2 - np.arange(nsector) * (2*np.pi/nsector)

    normed = kwargs.pop('normed', False)
    blowto = kwargs.pop('blowto', False)

    # Set the global information dictionary
    self._info['dir'], self._info['bins'], self._info['table'] = histogram(
        dir, var, bins, nsector, normed, blowto)

    return bins, nbins, nsector, colors, angles, kwargs
def contour(self, dir, var, **kwargs):
    """
    Draw the windrose as stacked closed lines (linear mode).

    One line is drawn per speed bin. Each line joins the sector centres and
    is stacked on top of the cumulative total of the lower bins. Extra
    keyword arguments are forwarded to ``self.plot`` (colour, width, ...).

    Mandatory:
    * dir : 1D array - directions the wind blows from, North centred
    * var : 1D array - values of the variable to compute, typically the
      wind speeds

    Optional:
    * nsector : integer - number of sectors (default 16, i.e. 22.5° each,
      aligned with the cardinal points).
    * bins : 1D array or integer - explicit bin edges, or a bin count.
      Default: linspace(min(var), max(var), 6).
    * blowto : bool - if True, rotate by pi to show where the wind blows to.
    * colors : a single color string used for every bin, or a tuple/list
      with one color per bin.
    * cmap : a matplotlib.cm Colormap used when ``colors`` is not given;
      a default colormap is used if neither is supplied.
    Any other kwargs: see help(pylab.plot).
    """
    bins, nbins, nsector, colors, angles, kwargs = self._init_plot(dir, var, **kwargs)

    table = self._info['table']
    # Append one extra angle and repeat the first sector's column so each
    # line closes back on its starting point.
    closed_angles = np.hstack((angles, angles[-1] - 2*np.pi/nsector))
    first_column = np.reshape(table[:, 0], (table.shape[0], 1))
    closed_vals = np.hstack((table, first_column))

    cumulative = 0
    for level in range(nbins):
        row = closed_vals[level, :]
        stacked = row + cumulative
        cumulative = cumulative + row
        lines = self.plot(closed_angles, stacked, color=colors[level],
                          zorder=ZBASE + nbins - level, **kwargs)
        self.patches_list.extend(lines)
    self._update()
def contourf(self, dir, var, **kwargs):
    """
    Plot a windrose in filled mode. For each var bin, a filled polygon is
    drawn between the cumulative total of the lower bins (through the sector
    centres) and the origin. Extra keyword arguments are forwarded to
    ``self.fill``; ``facecolor`` and ``edgecolor`` are ignored because the
    fill colour comes from ``colors``/``cmap``.

    Mandatory:
    * dir : 1D array - directions the wind blows from, North centred
    * var : 1D array - values of the variable to compute. Typically the wind
      speeds
    Optional:
    * nsector: integer - number of sectors used to compute the windrose
      table. If not set, nsector=16, then each sector will be 360/16=22.5°,
      and the resulting computed table will be aligned with the cardinal
      points.
    * bins : 1D array or integer - number of bins, or a sequence of
      bins variable. If not set, bins=6, then
      bins=linspace(min(var), max(var), 6)
    * blowto : bool. If True, the windrose will be pi rotated,
      to show where the wind blows to (useful for pollutant rose).
    * colors : string or tuple - one string color ('k' or 'black'), in this
      case all bins will be plotted in this color; a tuple of matplotlib
      color args (string, float, rgb, etc), different levels will be plotted
      in different colors in the order specified.
    * cmap : a cm Colormap instance from matplotlib.cm.
      - if cmap == None and colors == None, a default Colormap is used.
    others kwargs : see help(pylab.fill)
    """
    bins, nbins, nsector, colors, angles, kwargs = self._init_plot(dir, var, **kwargs)
    # Fill and edge colours are taken from `colors`; drop conflicting user values.
    kwargs.pop('facecolor', None)
    kwargs.pop('edgecolor', None)

    # closing lines: repeat the first sector so each polygon is closed
    angles = np.hstack((angles, angles[-1]-2*np.pi/nsector))
    vals = np.hstack((self._info['table'],
                      np.reshape(self._info['table'][:, 0],
                                 (self._info['table'].shape[0], 1))))
    offset = 0
    for i in range(nbins):
        val = vals[i, :] + offset
        offset += vals[i, :]
        zorder = ZBASE + nbins - i
        xs, ys = poly_between(angles, 0, val)
        patch = self.fill(xs, ys, facecolor=colors[i],
                          edgecolor=colors[i], zorder=zorder, **kwargs)
        self.patches_list.extend(patch)
    # Like contour/bar/box, refresh the axes after drawing; this call was missing here.
    self._update()
def bar(self, dir, var, **kwargs):
    """
    Draw the windrose as stacked bars: one bar segment per (sector, bin).

    Mandatory:
    * dir : 1D array - directions the wind blows from, North centred
    * var : 1D array - values of the variable to compute, typically the
      wind speeds

    Optional:
    * nsector : integer - number of sectors (default 16, i.e. 22.5° each,
      aligned with the cardinal points).
    * bins : 1D array or integer - explicit bin edges, or a bin count.
      Default: 6 bins between min(var) and max(var).
    * blowto : bool - if True, rotate by pi to show where the wind blows to.
    * colors : a single color string used for every bin, or a tuple/list
      with one color per bin.
    * cmap : a matplotlib.cm Colormap used when ``colors`` is not given;
      a default colormap is used if neither is supplied.
    * edgecolor : string - color of the bar edges (default: none). A
      non-string value raises ValueError. ``facecolor`` is ignored.
    * opening : float - fraction of each sector's angular width covered by
      its bar (default 0.8; 1.0 means no gap between sectors).
    """
    bins, nbins, nsector, colors, angles, kwargs = self._init_plot(dir, var, **kwargs)
    kwargs.pop('facecolor', None)
    edgecolor = kwargs.pop('edgecolor', None)
    if edgecolor is not None and not isinstance(edgecolor, str):
        raise ValueError('edgecolor must be a string color')

    fraction = kwargs.pop('opening', None)
    if fraction is None:
        fraction = 0.8
    sector_width = 2*np.pi/nsector
    width = sector_width*fraction

    table = self._info['table']
    for sector in range(nsector):
        base = 0
        for level in range(nbins):
            # Each segment starts where the previous bins of this sector end.
            if level:
                base += table[level - 1, sector]
            height = table[level, sector]
            rect = Rectangle((angles[sector] - width/2, base), width, height,
                             facecolor=colors[level], edgecolor=edgecolor,
                             zorder=ZBASE + nbins - level, **kwargs)
            self.add_patch(rect)
            # Keep one patch per bin (from the first sector) for the legend.
            if not sector:
                self.patches_list.append(rect)
    self._update()
def box(self, dir, var, **kwargs):
    """
    Draw the windrose as proportional stacked bars: one segment per
    (sector, bin), with segment width growing linearly with the bin index.

    Mandatory:
    * dir : 1D array - directions the wind blows from, North centred
    * var : 1D array - values of the variable to compute, typically the
      wind speeds

    Optional:
    * nsector : integer - number of sectors (default 16, i.e. 22.5° each,
      aligned with the cardinal points).
    * bins : 1D array or integer - explicit bin edges, or a bin count.
      Default: 6 bins between min(var) and max(var).
    * blowto : bool - if True, rotate by pi to show where the wind blows to.
    * colors : a single color string used for every bin, or a tuple/list
      with one color per bin.
    * cmap : a matplotlib.cm Colormap used when ``colors`` is not given;
      a default colormap is used if neither is supplied.
    * edgecolor : string - color of the bar edges (default: none). A
      non-string value raises ValueError. ``facecolor`` is ignored.
    """
    bins, nbins, nsector, colors, angles, kwargs = self._init_plot(dir, var, **kwargs)
    kwargs.pop('facecolor', None)
    edgecolor = kwargs.pop('edgecolor', None)
    if edgecolor is not None and not isinstance(edgecolor, str):
        raise ValueError('edgecolor must be a string color')

    # Widths from 0 (first bin) up to pi/16 (last bin).
    widths = np.linspace(0.0, np.pi/16, nbins)
    table = self._info['table']
    for sector in range(nsector):
        base = 0
        for level in range(nbins):
            if level:
                base += table[level - 1, sector]
            height = table[level, sector]
            width = widths[level]
            rect = Rectangle((angles[sector] - width/2, base), width, height,
                             facecolor=colors[level], edgecolor=edgecolor,
                             zorder=ZBASE + nbins - level, **kwargs)
            self.add_patch(rect)
            if not sector:
                self.patches_list.append(rect)
    self._update()
def colormap(cats, mplmap='auto', categorical=None):
    """Map a series of categories or numbers to hex colors using a matplotlib colormap.

    Generates both categorical and numerical colormaps.

    Args:
        cats (Sequence): categories or numerical values. If the first element
            has a ``magnitude`` attribute (a quantity with units), the values are
            converted with ``u.array`` and their magnitudes are used as numbers.
        mplmap (str): name of the matplotlib colormap, or ``'auto'`` to use
            ``DEF_CATEGORICAL`` for categorical data and ``DEF_SEQUENTIAL``
            otherwise.
        categorical (bool): If None (the default), treat the data as numerical
            when its first element is a real number (Python or NumPy numeric
            scalar) and as categorical otherwise. If True, treat it as
            categorical. If False, cast every value to float.

    Returns:
        List[str]: hexadecimal RGB color values in the form ``'#000102'``;
        an empty list for empty input.
    """
    import numbers

    # Should automatically choose the right colormaps for:
    #   categorical data
    #   sequential data (low, high important)
    #   diverging data (low, mid, high important)
    global DEF_SEQUENTIAL
    from matplotlib import cm
    # Prefer 'inferno' when this matplotlib provides it.
    if hasattr(cm, 'inferno'):
        DEF_SEQUENTIAL = 'inferno'
    else:
        DEF_SEQUENTIAL = 'BrBG'

    if len(cats) == 0:
        # Nothing to colour; avoid IndexError on cats[0] below.
        return []

    # strip units
    units = None  # TODO: build a color bar with units
    if hasattr(cats[0], 'magnitude'):
        arr = u.array(cats)
        units = arr.units
        cats = arr.magnitude
        is_categorical = False
    else:
        # numbers.Real covers int, float, and NumPy scalars such as np.int64 or
        # np.float32, which (float, int) missed and wrongly treated as categories.
        is_categorical = not isinstance(cats[0], numbers.Real)

    if categorical is not None:
        is_categorical = categorical

    if is_categorical:
        values = _map_categories_to_ints(cats)
        if mplmap == 'auto':
            mplmap = DEF_CATEGORICAL
    else:
        values = np.array(list(map(float, cats)))
        if mplmap == 'auto':
            mplmap = DEF_SEQUENTIAL

    rgb = _cmap_to_rgb(mplmap, values)
    hexcolors = [webcolors.rgb_to_hex(np.array(c)) for c in rgb]
    return hexcolors
def _get_axes_unit_labels(self, unit_x, unit_y):
    """Build the LaTeX unit suffixes for the x and y axis labels.

    Parameters
    ----------
    unit_x, unit_y : str
        Unit strings for the horizontal and vertical axes.

    Returns
    -------
    list of str
        ``[x_suffix, y_suffix]``. An entry stays ``''`` when the (h- and
        comoving-stripped) unit is ``'1'``, ``'u'`` or ``'unitary'``. If the
        dataset coordinates provide ``image_units`` or ``default_unit_label``
        for the data source's axis, that label is used instead.

    Raises
    ------
    RuntimeError
        If a unit carries a power of ``h`` other than 0 or -1.
    """
    axes_unit_labels = ['', '']
    for i, un in enumerate((unit_x, unit_y)):
        # Reset per axis: an h^-1 or comoving marker found on the x unit must
        # not be appended to the y label (and vice versa).
        comoving = False
        hinv = False
        unn = None
        if hasattr(self.data_source, 'axis'):
            if hasattr(self.ds.coordinates, "image_units"):
                # This *forces* an override
                unn = self.ds.coordinates.image_units[
                    self.data_source.axis][i]
            elif hasattr(self.ds.coordinates, "default_unit_label"):
                axax = getattr(self.ds.coordinates,
                               "%s_axis" % ("xy"[i]))[self.data_source.axis]
                unn = self.ds.coordinates.default_unit_label.get(
                    axax, None)
        if unn is not None:
            axes_unit_labels[i] = r'\ \ \left(' + unn + r'\right)'
            continue
        # Use sympy to factor h out of the unit.  In this context 'un'
        # is a string, so we call the Unit constructor.
        expr = Unit(un, registry=self.ds.unit_registry).expr
        h_expr = Unit('h', registry=self.ds.unit_registry).expr
        # See http://docs.sympy.org/latest/modules/core.html#sympy.core.expr.Expr
        h_power = expr.as_coeff_exponent(h_expr)[1]
        # un is now the original unit, but with h factored out.
        un = str(expr*h_expr**(-1*h_power))
        un_unit = Unit(un, registry=self.ds.unit_registry)
        cm_expr = Unit('cm').expr
        # A trailing "cm" that is not the centimetre symbol marks a comoving unit.
        if str(un).endswith('cm') and cm_expr not in un_unit.expr.atoms():
            comoving = True
            un = un[:-2]
        # no length units besides code_length end in h so this is safe
        if h_power == -1:
            hinv = True
        elif h_power != 0:
            # It doesn't make sense to scale a position by anything
            # other than h**-1
            raise RuntimeError(
                "unsupported power of h (%s) in axis unit %r" % (h_power, un))
        if un not in ['1', 'u', 'unitary']:
            if un in formatted_length_unit_names:
                un = formatted_length_unit_names[un]
            else:
                un = Unit(un, registry=self.ds.unit_registry)
                un = un.latex_representation()
                # Raw strings: '\,' and '\ ' are LaTeX spacing, not Python escapes.
                if hinv:
                    un = un + r'\,h^{-1}'
                if comoving:
                    un = un + r'\,(1+z)^{-1}'
                pp = un[0]
                if pp in latex_prefixes:
                    symbol_wo_prefix = un[1:]
                    if symbol_wo_prefix in prefixable_units:
                        un = un.replace(
                            pp, "{" + latex_prefixes[pp] + "}", 1)
            axes_unit_labels[i] = r'\ \ (' + un + ')'
    return axes_unit_labels
def plot_WC_layers(dayDataPath, logDataPath):
    """Plot water content of the ten soil compartments with root and sensor0 depth.

    Reads the daily CSV (columns named by ``configHolder``, ``HEADER_NUM``
    header lines skipped) and the algorithm log CSV (3 header lines skipped),
    draws WC1..WC10 as an image over time, overlays the negated root depth
    ``Z`` and the sensor0 depth track, and saves the figure as
    ``<prefixOutput><day file stem>_All_WC_root.png``.

    Parameters
    ----------
    dayDataPath : str
        Path of the daily data CSV.
    logDataPath : str
        Path of the algorithm log CSV.
    """
    # load day data
    colsDefTuple = configHolder.getCSVDayDataColumnTuple()
    dailyData = np.genfromtxt(
        dayDataPath, delimiter=',', skip_header=HEADER_NUM,
        dtype=float, names=colsDefTuple)
    WCs = dailyData[['WC1', 'WC2', 'WC3', 'WC4',
                     'WC5', 'WC6', 'WC7', 'WC8', 'WC9', 'WC10']]
    root = dailyData['Z']
    # One row per compartment, one column per day.
    data = np.transpose([list(r) for r in WCs])

    fig, ax = plt.subplots()
    try:
        # plot water content in all vertical compartments
        cax = ax.imshow(data, interpolation='none', cmap="Spectral",
                        extent=[0, len(root), -1, 0], aspect="auto")
        # plot depth of root (negated so depth points downwards)
        reversedRoot = -1 * root
        ax.plot(reversedRoot)

        # load log data
        algLog = AlgorithmLog()
        logHeader = tuple(algLog.getHeader())
        logData = np.genfromtxt(
            logDataPath, delimiter=',', skip_header=3,
            dtype=float, names=logHeader)
        # plot sensor0 track; the depth index is mapped to metres as
        # -(depth * 10 + 5) * 0.01 (assumed 10 cm layers, centre offset 5 cm)
        revSensor0TrackList = -1 * \
            (np.array(logData['sensor0_depth']) * 10 + 5) * 0.01
        ax.plot(revSensor0TrackList, color="red")

        fig.colorbar(cax, orientation='horizontal')
        ax.set_xlabel('Time(day)')
        ax.set_ylabel('Soil Depth(m)')
        # NOTE(review): the "{}" placeholders are never filled in; the title is
        # rendered literally. Supply the start depth and ref value if intended.
        ax.set_title("sensor0 start at {}cm and ref={}")

        pathFigureOut = str(Path(
            prefixOutput + r'{}_All_WC_root.png'.format(
                basename(dayDataPath).split('.')[0])).resolve())
        fig.savefig(pathFigureOut)
    finally:
        # Close (not just clear) the figure so repeated calls don't accumulate
        # open figures in pyplot's registry.
        plt.close(fig)
individual_N2_X_Y_cluster_stage_watershed.py 文件源码
项目:CElegansBehaviour
作者: ChristophKirst
项目源码
文件源码
阅读 21
收藏 0
点赞 0
评论 0
def makeSegmentation(d, s, nbins = 256, verbose = True, sigma = 0.75, min_distance = 1):
    """Segment the smoothed (log distance, rotation) density of stage ``s`` by watershed.

    Loads ``data_stage<s>_<features[0]>.npy`` (distances) and
    ``data_stage<s>_<features[1]>.npy`` (rotations) from ``datadir``, takes
    column ``d`` of each, drops rows that are NaN/inf after the log, builds a
    smoothed 2-D histogram with ``smooth`` and runs a watershed seeded at its
    local maxima.

    Parameters
    ----------
    d : int
        Column (delay index) to use from the loaded arrays.
    s : int
        Stage number used in the file names.
    nbins : int
        Number of histogram bins per dimension.
    verbose : bool
        If true, plot intermediate images and the segmentation.
    sigma : float
        Smoothing width passed to ``smooth`` for both dimensions.
    min_distance : int
        Minimum peak separation passed to ``peak_local_max``.

    Returns
    -------
    labels : ndarray
        Watershed label image. When ``verbose`` is true, bins where the
        unsmoothed histogram is zero are set to label 0.
    """
    fn = os.path.join(datadir, 'data_stage%d_%s.npy' % (s, features[0]))
    dists = np.load(fn)
    fn = os.path.join(datadir, 'data_stage%d_%s.npy' % (s, features[1]))
    rots = np.load(fn)

    ddata = np.vstack([np.log(dists[:, d]), rots[:, d]]).T
    nanids = np.logical_or(np.any(np.isnan(ddata), axis=1),
                           np.any(np.isinf(ddata), axis=1))
    ddata = ddata[~nanids, :]

    imgbin = None
    img2 = smooth(ddata, nbins=[nbins, nbins], sigma=(sigma, sigma))
    # NOTE(review): `indices=False` was removed from skimage.peak_local_max in
    # newer scikit-image releases; confirm the pinned version.
    local_maxi = peak_local_max(img2, indices=False, min_distance=min_distance)
    imgm2 = img2.copy()
    imgm2[local_maxi] = 3 * imgm2.max()

    if verbose:
        imgbin = smooth(ddata, nbins=[nbins, nbins], sigma=None)
        plt.figure(220); plt.clf()
        plt.subplot(2, 2, 1)
        plt.imshow(imgbin)
        plt.subplot(2, 2, 2)
        plt.imshow(img2)
        plt.subplot(2, 2, 3)
        plt.imshow(imgm2, cmap=plt.cm.jet, interpolation='nearest')

    markers = ndi.label(local_maxi)[0]
    labels = watershed(-img2, markers, mask=None)
    # Call form works on both Python 2 and 3 (single argument).
    print("max labels: %d" % labels.max())

    if verbose:
        # NOTE(review): 'box-forced' is not accepted by newer matplotlib; confirm version.
        fig, axes = plt.subplots(ncols=3, sharex=True, sharey=True,
                                 subplot_kw={'adjustable': 'box-forced'})
        ax0, ax1, ax2 = axes
        ax0.imshow(img2, cmap=plt.cm.jet, interpolation='nearest')
        ax0.set_title('PDF')
        # NOTE(review): this masking also changes the returned labels, but only
        # when verbose is true; confirm whether it should apply unconditionally.
        labels[imgbin == 0] = 0
        # The -1 pixel only anchors the colour range of the plot; apply it to a
        # copy so the returned segmentation does not gain a spurious -1 class.
        display_labels = labels.copy()
        display_labels[0, 0] = -1
        ax1.imshow(display_labels, cmap=plt.cm.rainbow, interpolation='nearest')
        ax1.set_title('Segmentation on Data')

    return labels
# Classification of a specific worm based on the segmentation above
individual_N2_X_Y_cluster_stage_watershed.py 文件源码
项目:CElegansBehaviour
作者: ChristophKirst
项目源码
文件源码
阅读 21
收藏 0
点赞 0
评论 0
def classifyWorm(labels, wid, d, s, nbins = 256, verbose = True):
    """Assign each time point of worm ``wid`` to a segmentation class.

    Computes distances and rotations over ``delays[d] + 1`` samples for stage
    ``s``, forms (log distance, rotation) pairs, rescales each coordinate to
    ``[0, nbins - 1]`` and looks the resulting bin up in ``labels``.

    Parameters
    ----------
    labels : ndarray
        2-D label image indexed as ``labels[distance_bin, rotation_bin]``.
    wid : int
        Worm id used to index ``XYdata``.
    d : int
        Index into ``delays``.
    s : int
        Stage passed to the WormData calculations and plots.
    nbins : int
        Number of bins per coordinate.
    verbose : bool or int
        If truthy, plot the classified trace; if greater than 2, draw the
        additional diagnostic plots.

    Returns
    -------
    pred2 : ndarray of float
        One class per row of the rotation data; -1 where the data is NaN/inf.
    """
    XYwdata = XYdata[wid].copy()
    w = wd.WormData(XYwdata[:, 0:2], stage=XYwdata[:, -1], valid=XYwdata[:, 0] != 1,
                    label=('x', 'y'), wid=wid)
    w.replaceInvalid()

    ds = w.calculateDistances(n=delays[d]+1, stage=s)
    rs = w.calculateRotations(n=delays[d]+1, stage=s)

    ddata = np.vstack([np.log(ds[:, -1]), rs[:, -1]]).T
    nanids = np.logical_or(np.any(np.isnan(ddata), axis=1),
                           np.any(np.isinf(ddata), axis=1))
    ddata = ddata[~nanids, :]

    pred2 = -np.ones(rs.shape[0])
    if ddata.shape[0] > 0:
        # NOTE(review): the rescaling uses this worm's own min/max, which may
        # not match the binning `labels` was built with; confirm.
        for i in range(2):
            col = ddata[:, i] - ddata[:, i].min()
            span = col.max()
            # A constant column would divide by zero; it maps to bin 0 instead.
            if span > 0:
                col = (col / span) * (nbins-1)
            ddata[:, i] = col
        bin_ids = np.asarray(ddata, dtype=int)
        bin_ids = np.minimum(bin_ids, nbins-1)
        # Vectorised lookup of each point's label.
        pred2[~nanids] = labels[bin_ids[:, 0], bin_ids[:, 1]]

    # Integer shift: `//` matches Python 2's int `/` and stays an int on Python 3.
    shift = delays[d] // 2
    if verbose:
        plt.figure(506); plt.clf()
        w.plotTrace(ids=shiftData(pred2, shift, nan=-1), stage=s)

        if verbose > 2:
            rds = w.calculateRotations(n=delays[d] + 1, stage=s)
            plt.figure(510); plt.clf()
            w.plotDataColor(data=shiftData(rds[:, -1], shift, nan=-1), c=pred2,
                            lw=0, s=20, stage=s, cmap=cm.rainbow)

            dts = w.calculateDistances(n=delays[d] + 1, stage=s)
            # NOTE(review): figure 510 is reused and cleared here, erasing the
            # rotation plot just drawn; a separate figure number may be intended.
            plt.figure(510); plt.clf()
            w.plotDataColor(data=dts[:, -1], c=pred2, lw=0, s=20, stage=s,
                            cmap=cm.rainbow)

            plt.figure(507); plt.clf()
            w.plotTrajectory(stage=s, colordata=shiftData(
                rds[:, -1] / (delays[d]+1) / np.pi, shift, nan=-.1))

            dist2 = w.calculateLengths(n=200)
            plt.figure(511); plt.clf()
            w.plotDataColor(data=dist2[:, -1], c=pred2, lw=0, s=20)

    return pred2
#assume classes to be 0...N
def save_contour(netD, filename, cuda=False):
    """Evaluate ``netD`` on a 2-D grid and save filled plus labelled contours of its output.

    The grid covers ``np.arange(-25.0, 25.0, 0.1)`` on both axes. Each point
    is fed to ``netD`` as a row of an ``(h*w, 2, 1, 1)`` tensor.

    Parameters
    ----------
    netD : callable
        Torch module (or function) taking that tensor and returning one value
        per grid point.
    filename : str
        Output path passed to ``savefig``.
    cuda : bool
        If true, move the input to the GPU before the forward pass.
    """
    # Style tweaks scoped to this figure via rc_context instead of being
    # written into the global rcParams for every later plot.
    style = {
        'xtick.direction': 'out',
        'ytick.direction': 'out',
        'contour.negative_linestyle': 'solid',
    }

    # gen grid
    delta = 0.1
    x = np.arange(-25.0, 25.0, delta)
    y = np.arange(-25.0, 25.0, delta)
    X, Y = np.meshgrid(x, y)

    # convert numpy array to torch variable
    (h, w) = X.shape
    XY = np.concatenate((X.reshape((h*w, 1, 1, 1)), Y.reshape((h*w, 1, 1, 1))), axis=1)
    inputs = Variable(torch.Tensor(XY))
    if cuda:
        inputs = inputs.cuda()

    # forward
    output = netD(inputs)

    # convert torch variable to numpy array
    Z = output.data.cpu().view(-1).numpy().reshape(h, w)

    # plot and save
    with matplotlib.rc_context(style):
        fig = plt.figure()
        try:
            plt.contourf(X, Y, Z)
            CS2 = plt.contour(X, Y, Z, alpha=.7, colors='k')
            plt.clabel(CS2, inline=1, fontsize=10, colors='k')
            plt.title('Simplest default with labels')
            plt.savefig(filename)
        finally:
            # Close even if savefig fails so the figure is not leaked.
            plt.close(fig)
plotting.py 文件源码
项目:PyDataLondon29-EmbarrassinglyParallelDAWithAWSLambda
作者: SignalMedia
项目源码
文件源码
阅读 17
收藏 0
点赞 0
评论 0
def _make_plot(self):
    """Draw a scatter of ``self.x`` vs ``self.y`` on ``self.axes[0]``.

    Point colours come from ``c`` (a column name or values) or ``color``;
    passing both raises TypeError. A colorbar is added when the ``colorbar``
    keyword is true, defaulting to true when a colormap is set or ``c`` names
    a column. Error bars are drawn if ``_get_errorbars`` returns any.
    """
    x, y, c, data = self.x, self.y, self.c, self.data
    ax = self.axes[0]

    c_is_column = com.is_hashable(c) and c in self.data.columns

    # plot a colorbar only if a colormap is provided or necessary
    cb = self.kwds.pop('colorbar', self.colormap or c_is_column)

    # pandas uses colormap, matplotlib uses cmap.
    cmap = self.colormap or 'Greys'
    cmap = self.plt.cm.get_cmap(cmap)
    color = self.kwds.pop("color", None)
    if c is not None and color is not None:
        raise TypeError('Specify exactly one of `c` and `color`')
    elif c is None and color is None:
        c_values = self.plt.rcParams['patch.facecolor']
    elif color is not None:
        c_values = color
    elif c_is_column:
        c_values = self.data[c].values
    else:
        c_values = c

    if self.legend and hasattr(self, 'label'):
        label = self.label
    else:
        label = None
    scatter = ax.scatter(data[x].values, data[y].values, c=c_values,
                         label=label, cmap=cmap, **self.kwds)
    if cb:
        kws = dict(ax=ax)
        if self.mpl_ge_1_3_1():
            kws['label'] = c if c_is_column else ''
        # Use the collection just drawn: ax.collections[0] is an older artist
        # when plotting onto axes that already contain collections.
        self.fig.colorbar(scatter, **kws)

    if label is not None:
        self._add_legend_handle(scatter, label)
    else:
        self.legend = False

    errors_x = self._get_errorbars(label=x, index=0, yerr=False)
    errors_y = self._get_errorbars(label=y, index=0, xerr=False)
    if len(errors_x) > 0 or len(errors_y) > 0:
        err_kwds = dict(errors_x, **errors_y)
        err_kwds['ecolor'] = scatter.get_facecolor()[0]
        ax.errorbar(data[x].values, data[y].values,
                    linestyle='none', **err_kwds)
def mpl_palette(name, n_colors=6, extrema=False, cycle=False):
    """Return discrete colors from a matplotlib palette.

    Note that this handles the qualitative colorbrewer palettes
    properly, although if you ask for more colors than a particular
    qualitative palette can provide you will get fewer than you are
    expecting. Names in ``SEABORN_PALETTES`` are cycled to always return
    ``n_colors`` entries.

    Parameters
    ----------
    name : string
        Name of the palette. This should be a named matplotlib colormap.
    n_colors : int
        Number of discrete colors in the palette.
    extrema : boolean
        If True, include the extrema of the palette; otherwise sample the
        centres of ``n_colors`` equal intervals of [0, 1].
    cycle : boolean
        If True, return a itertools.cycle.

    Returns
    -------
    palette : list or itertools.cycle
        Colors as RGB tuples (or a cycle over them).

    Raises
    ------
    ValueError
        If ``name`` is not a known palette or colormap.
    """
    if name in SEABORN_PALETTES:
        palette = SEABORN_PALETTES[name]
        # Always return as many colors as we asked for
        pal_cycle = itertools.cycle(palette)
        palette = [next(pal_cycle) for _ in range(n_colors)]
    elif name in dir(mpl.cm) or name[:-2] in dir(mpl.cm):
        try:
            cmap = getattr(mpl.cm, name)
        except AttributeError:
            # e.g. "Blues_d": the prefix exists but the full name does not.
            # Report it the same way as any other unknown name.
            raise ValueError("%s is not a valid palette name" % name)
        if extrema:
            bins = np.linspace(0, 1, n_colors)
        else:
            # Odd-indexed points of a 2n+1 grid: the interval midpoints.
            bins = np.linspace(0, 1, n_colors * 2 - 1 + 2)[1:-1:2]
        palette = list(map(tuple, cmap(bins)[:, :3]))
    else:
        raise ValueError("%s is not a valid palette name" % name)

    if cycle:
        return itertools.cycle(palette)
    else:
        return palette
def plot_embedding(embed, labels, plot_type='t-sne', title="", tsne_params=None, save_path=None,
                   legend=True, label_dict=None, label_order=None, legend_outside=False, alpha=0.7):
    """
    Project an embedding onto two dimensions and scatter it, coloured by label.

    @param embed: embedding matrix (one row per item); it is standardised
        with ``scale`` before projection
    @param labels: sequence of labels for the rows of embed
    @param plot_type: 'pca' or 't-sne'; any other value raises ValueError
    @param title: title of plot
    @param tsne_params: keyword arguments for TSNE (default: none)
    @param save_path: path of where to save (not saved if None)
    @param legend: bool to show legend
    @param label_dict: dict that maps labels to display names
        (eg. {0:'rock', 1:'edm'}); labels are shown as-is if None
    @param label_order: list of display names giving the plotting/legend order
    @param legend_outside: place the legend outside the axes (also done
        automatically for 10 or more labels)
    @param alpha: marker transparency
    """
    if tsne_params is None:
        tsne_params = {}
    if plot_type not in ('pca', 't-sne'):
        raise ValueError("plot_type must be 'pca' or 't-sne', got %r" % (plot_type,))

    def display_name(lab):
        # Fall back to the raw label when no mapping is supplied.
        return label_dict[lab] if label_dict is not None else lab

    plt.figure()
    N = len(set(labels))
    colors = cm.rainbow(np.linspace(0, 1, N))
    scaled_embed = scale(embed)

    if plot_type == 'pca':
        pca = PCA(n_components=2)
        pca.fit(scaled_embed)
        # note: will take a while if embedding is large
        comp1, comp2 = pca.components_
        # Project the same standardised data the components were fitted on.
        comp1, comp2 = scaled_embed.dot(comp1), scaled_embed.dot(comp2)
    else:
        tsne = TSNE(**tsne_params)
        comp1, comp2 = tsne.fit_transform(scaled_embed).T

    unique_labels = list(set(labels))
    try:
        # Deterministic label -> colour assignment (set order can vary by run).
        unique_labels.sort()
    except TypeError:
        pass  # mixed, unorderable label types: keep set order
    if label_order is not None:
        unique_labels = sorted(unique_labels, key=lambda l: label_order.index(display_name(l)))

    # label -> row indices carrying that label
    l_dict = {lab: np.array([j for j, other in enumerate(labels) if other == lab])
              for lab in unique_labels}
    for l, color in zip(unique_labels, colors):
        plt.scatter(comp1[l_dict[l]], comp2[l_dict[l]],
                    color=color, label=display_name(l), alpha=alpha)

    plt.title(title)
    lgd = None
    if legend:
        if N >= 10 or legend_outside:
            lgd = plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left')
        else:
            lgd = plt.legend(loc='best')
    if save_path is not None:
        # Only pass the legend as an extra artist when one was created.
        extra = (lgd,) if lgd is not None else None
        plt.savefig(save_path, bbox_extra_artists=extra, bbox_inches='tight')