def show_C12(C12, qz_ind=0, qr_ind=0, N1=None, N2=None, vmin=None, vmax=None, title=False):
    """Display one q-slice of a two-time correlation array on a log colour scale.

    Parameters
    ----------
    C12 : array
        Three-index array sliced as ``C12[N1:N2, N1:N2, g12_num]`` where
        ``g12_num = qz_ind * num_qr + qr_ind`` (presumably [t1, t2, q] —
        confirm against the producer of the array).
    qz_ind, qr_ind : int
        Indices of the qz / qr bin to show.
    N1, N2 : int, optional
        Frame range to show; default ``0`` and the module-level ``Nming``.
    vmin, vmax : float, optional
        Limits of the ``LogNorm`` colour scale; default ``1`` and ``1.02``.
    title : bool
        If True, title the plot with the frame range and the q centres.

    Notes
    -----
    Reads the module-level names ``num_qr``, ``Nming``, ``timeperframe``,
    ``qz_center`` and ``qr_center``.
    """
    g12_num = qz_ind * num_qr + qr_ind
    if N1 is None:
        N1 = 0
    if N2 is None:
        N2 = Nming
    if vmin is None:
        vmin = 1
    if vmax is None:
        vmax = 1.02
    # Slice the array that was passed in; the original read the global
    # ``g12b`` here and silently ignored the ``C12`` argument.
    data = C12[N1:N2, N1:N2, g12_num]
    fig, ax = plt.subplots()
    # Both axes are sliced with the same N1:N2 range, so one time span
    # serves for x and y.
    t_span = data.shape[0] * timeperframe
    im = ax.imshow(data, origin='lower', cmap='viridis',
                   norm=LogNorm(vmin, vmax),
                   extent=[0, t_span, 0, t_span])
    if title:
        ax.set_title('%s-%s frames--Qz= %s--Qr= %s' % (N1, N2, qz_center[qz_ind], qr_center[qr_ind]))
    ax.set_xlabel(r'$t_1$ $(s)$', fontsize=18)
    ax.set_ylabel(r'$t_2$ $(s)$', fontsize=18)
    fig.colorbar(im)
python类LogNorm()的实例源码
def prettyPlot(samps, dat, hid):
    """Show a heat map of MOE cell speed-up factors with each value written in its cell.

    Parameters
    ----------
    samps : list
        Tick labels used on both axes (number of experts on x, minibatch
        size on y, reversed).
    dat : 2-D array
        ``len(samps)`` x ``len(samps)`` matrix of speed-up factors; it is
        both annotated as text and colour-mapped on a log scale.
    hid : object
        Not used by this body; kept for call-site compatibility.
    """
    fig, ax = plt.subplots()
    sz = 18
    plt.rc('xtick', labelsize=sz)
    plt.rc('ytick', labelsize=sz)
    # The leading ``1`` pads the label list to line up with the locator's
    # first tick.
    ax.set_xticklabels([1] + samps, fontsize=sz)
    ax.set_yticklabels([1] + samps[::-1], fontsize=sz)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.set_xlabel('Number of Experts', fontsize=sz + 2)
    ax.set_ylabel('Minibatch Size', fontsize=sz + 2)
    ax.set_title('MOE Cell Speedup Factor', fontsize=sz + 4)
    # Show cell values (first four characters of each number)
    for i in range(len(samps)):
        for j in range(len(samps)):
            ax.text(i, j, str(dat[i, j])[:4], ha='center', va='center', fontsize=sz, color='white')
    # Colour the same matrix that is annotated; the original referenced an
    # undefined/global ``cellTimes`` instead of the ``dat`` argument.
    plt.imshow(dat, cmap='viridis', norm=colors.LogNorm(vmin=dat.min(), vmax=dat.max()))
    plt.show()
def mplot_function(f, vmin, vmax, logscale):
    """Plot a 2-D dolfin function with matplotlib.

    Parameters
    ----------
    f : dolfin Function
        Cellwise (one value per mesh cell), scalar, or 2-D vector function.
    vmin, vmax : float or None
        Colour-scale limits.
    logscale : bool
        Use a logarithmic colour scale for cellwise/scalar functions.

    Returns
    -------
    The artist created by ``plt.tripcolor`` or ``plt.quiver``; ``None`` when
    the function's value rank is not handled below.

    Raises
    ------
    AttributeError
        If the mesh is not 2-D, or a vector function is not 2-D.
    """
    mesh = f.function_space().mesh()
    if mesh.geometry().dim() != 2:
        raise AttributeError('Mesh must be 2D')
    # DG0 cellwise function
    if f.vector().size() == mesh.num_cells():
        C = f.vector().get_local()
        if logscale:
            # Limits go into the norm itself: recent Matplotlib raises when
            # ``norm=`` is combined with ``vmin=``/``vmax=`` keywords.
            return plt.tripcolor(mesh2triang(mesh), C, norm=cls.LogNorm(vmin=vmin, vmax=vmax))
        return plt.tripcolor(mesh2triang(mesh), C, vmin=vmin, vmax=vmax)
    # Scalar function, interpolated to vertices
    elif f.value_rank() == 0:
        C = f.compute_vertex_values(mesh)
        if logscale:
            return plt.tripcolor(mesh2triang(mesh), C, norm=cls.LogNorm(vmin=vmin, vmax=vmax))
        return plt.tripcolor(mesh2triang(mesh), C, shading='gouraud', vmin=vmin, vmax=vmax)
    # Vector function, interpolated to vertices
    elif f.value_rank() == 1:
        w0 = f.compute_vertex_values(mesh)
        if len(w0) != 2 * mesh.num_vertices():
            raise AttributeError('Vector field must be 2D')
        X = mesh.coordinates()[:, 0]
        Y = mesh.coordinates()[:, 1]
        # Vertex values are laid out as all U components followed by all V.
        U = w0[:mesh.num_vertices()]
        V = w0[mesh.num_vertices():]
        C = np.sqrt(U * U + V * V)
        return plt.quiver(X, Y, U, V, C, units='x', headaxislength=7, headwidth=7, headlength=7, scale=4, pivot='middle')
# Plot a generic dolfin object (if supported)
def _get_cmap_data(data, kwargs):
"""Get normalized values to be used with a colormap.
Parameters
----------
data : array_like
cmap_min : Optional[float] or "min"
By default 0. If "min", minimum value of the data.
cmap_max : Optional[float]
By default, maximum value of the data
cmap_normalize : str or colors.Normalize
Returns
-------
normalizer : colors.Normalize
normalized_data : array_like
"""
norm = kwargs.pop("cmap_normalize", None)
if norm == "log":
cmap_max = kwargs.pop("cmap_max", data.max())
cmap_min = kwargs.pop("cmap_min", data[data > 0].min())
norm = colors.LogNorm(cmap_min, cmap_max)
elif not norm:
cmap_max = kwargs.pop("cmap_max", data.max())
cmap_min = kwargs.pop("cmap_min", 0)
if cmap_min == "min":
cmap_min = data.min()
norm = colors.Normalize(cmap_min, cmap_max, clip=True)
return norm, norm(data)
def getNorm(self):
    """Build the colour normaliser selected by ``self.normScale``.

    The upper limit is ``10**self.scale``.  The lower limit is 0.0 for a
    'linear' scale and 1e-10 for a 'log' scale.

    Returns
    -------
    A ``pltLinNorm`` or ``pltLogNorm`` instance.

    Raises
    ------
    ValueError
        If ``self.normScale`` is neither 'linear' nor 'log'.
    """
    mx = 10**self.scale
    if self.normScale == 'linear':
        return pltLinNorm(0.0, mx)
    if self.normScale == 'log':
        # A log scale cannot start at 0, so use a tiny positive floor.
        return pltLogNorm(1e-10, mx)
    # Report the offending setting; the old message formatted the never-
    # assigned local ``norm`` and so raised UnboundLocalError instead.
    raise ValueError('Invalid norm %s.' % self.normScale)
def plot_2d_hist(x1, x2, bins=10):
    """Show a 2-D histogram of (*x1*, *x2*) with log-scaled counts and a colour bar.

    Parameters
    ----------
    x1, x2 : array_like
        Sample coordinates.
    bins : int or sequence
        Forwarded to ``plt.hist2d`` (it was previously ignored and 10 bins
        were always used).
    """
    plt.hist2d(x1, x2, bins=bins, norm=LogNorm())
    plt.colorbar()
    plt.show()
def plot(self):
    """Redraw the widget as a log-scaled 2-D histogram of two localisation fields.

    Reads ``self.locs[self.field_x]`` and ``self.locs[self.field_y]``, drops
    pairs where either value is not finite, draws into ``self.figure`` with
    a colour bar, and installs a RectangleSelector that reports to
    ``self.on_rect_select``.
    """
    # Prepare the data
    x = self.locs[self.field_x]
    y = self.locs[self.field_y]
    valid = np.isfinite(x) & np.isfinite(y)
    x = x[valid]
    y = y[valid]
    # Prepare the figure
    self.figure.clear()
    axes = self.figure.add_subplot(111)
    bins_x = lib.calculate_optimal_bins(x, 1000)
    bins_y = lib.calculate_optimal_bins(y, 1000)
    # Only the image is needed (for the colour bar); counts/edges are unused.
    image = axes.hist2d(x, y, bins=[bins_x, bins_y], norm=LogNorm())[3]
    # np.ptp works on every NumPy version; the ndarray.ptp method was
    # removed in NumPy 2.0.
    x_range = np.ptp(x)
    axes.set_xlim([bins_x[0] - 0.05 * x_range, x.max() + 0.05 * x_range])
    y_range = np.ptp(y)
    axes.set_ylim([bins_y[0] - 0.05 * y_range, y.max() + 0.05 * y_range])
    self.figure.colorbar(image, ax=axes)
    axes.grid(False)
    axes.get_xaxis().set_label_text(self.field_x)
    axes.get_yaxis().set_label_text(self.field_y)
    self.selector = RectangleSelector(axes, self.on_rect_select, useblit=False,
                                      rectprops=dict(facecolor='green', alpha=0.2, fill=True))
    self.canvas.draw()
def showLog(im, cmap='jet'):
    """Open a new figure showing *im* on a logarithmic colour scale with a colour bar.

    Returns
    -------
    tuple
        ``(figure, image)``: the newly created Figure and the AxesImage.
    """
    figure = plt.figure()
    image = plt.imshow(im, cmap=cmap, norm=col.LogNorm())
    figure.colorbar(image)
    return figure, image
cb_dap_plates_mpl5_presentation.py 文件源码
项目:bates_galaxies_lab
作者: aleksds
项目源码
文件源码
阅读 25
收藏 0
点赞 0
评论 0
def daplot(quantity, qmin, qmax):
    """Draw *quantity* as a log-scaled cool-warm map with a colour bar.

    The module-level ``ax`` gets limits ``0..size`` on both axes and empty
    tick labels.  The image itself is drawn through pyplot on the current
    axes (presumably ``ax`` — confirm at the call site), normalised between
    *qmin* and *qmax*.  Colour-bar tick labels use size 14.
    """
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(0, size)
    for clear_labels in (ax.set_xticklabels, ax.set_yticklabels):
        clear_labels(())
    log_scale = colors.LogNorm(vmin=qmin, vmax=qmax)
    plt.imshow(quantity, origin='lower', interpolation='nearest',
               norm=log_scale, cmap=cm.coolwarm)
    colour_bar = plt.colorbar()
    colour_bar.ax.tick_params(labelsize=14)
# minimum and maximum emission-line fluxes for plot ranges
def daplot(quantity, qmin, qmax):
    """Draw *quantity* as a log-scaled cool-warm map with a colour bar.

    The module-level ``ax`` gets limits ``0..size`` on both axes and empty
    tick labels.  The image itself is drawn through pyplot on the current
    axes (presumably ``ax`` — confirm at the call site), normalised between
    *qmin* and *qmax*.
    """
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(0, size)
    for clear_labels in (ax.set_xticklabels, ax.set_yticklabels):
        clear_labels(())
    log_scale = colors.LogNorm(vmin=qmin, vmax=qmax)
    plt.imshow(quantity, origin='lower', interpolation='nearest',
               norm=log_scale, cmap=cm.coolwarm)
    plt.colorbar()
# minimum and maximum emission-line fluxes for plot ranges
def daplot(quantity, qmin, qmax):
    """Draw *quantity* as a log-scaled cool-warm map with a colour bar.

    The module-level ``ax`` gets limits ``0..size`` on both axes and empty
    tick labels.  The image itself is drawn through pyplot on the current
    axes (presumably ``ax`` — confirm at the call site), normalised between
    *qmin* and *qmax*.
    """
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(0, size)
    for clear_labels in (ax.set_xticklabels, ax.set_yticklabels):
        clear_labels(())
    log_scale = colors.LogNorm(vmin=qmin, vmax=qmax)
    plt.imshow(quantity, origin='lower', interpolation='nearest',
               norm=log_scale, cmap=cm.coolwarm)
    plt.colorbar()
# minimum and maximum emission-line fluxes for plot ranges
def plot_price(smoothed_prices):
    """Map the 2015 average price paid and save it as ``price_paid.png`` under OUTPUT_PATH.

    ``10**(smoothed_prices - 3)`` is drawn with ``plot_over_map`` on a log
    colour scale from 150 to 1000 (presumably prices in thousands of pounds,
    per the colour-bar label).  The figure is enlarged to 36x36 inches
    before saving.
    """
    price_thousands = 10**(smoothed_prices - 3)
    plot_over_map(price_thousands, norm=LogNorm(1.5e2, 1e3))
    tick_values = sp.linspace(2e2, 1e3, 9)
    colour_bar = plt.colorbar(fraction=0.03, ticks=tick_values,
                              format=FormatStrFormatter(u'£%dk'))
    colour_bar.set_label(u'price paid (£1000s)')
    plt.title('2015 Average Price Paid')
    figure = plt.gcf()
    figure.set_size_inches(36, 36)
    figure.savefig(os.path.join(OUTPUT_PATH, 'price_paid.png'), bbox_inches='tight')
def plot_relative_price(relative_prices):
    """Map price paid relative to commute time and save it as ``relative_price.png``.

    ``10**relative_prices`` is drawn with ``plot_over_map`` on a log colour
    scale from 0.5 to 2, with colour-bar ticks at 0.5, 1.0, 1.5 and 2.0.
    The figure is enlarged to 36x36 inches before saving under OUTPUT_PATH.
    """
    ratios = 10**relative_prices
    plot_over_map(ratios, norm=LogNorm(0.5, 2))
    tick_values = sp.linspace(0.5, 2, 4)
    colour_bar = plt.colorbar(fraction=0.03, ticks=tick_values,
                              format=FormatStrFormatter('x%.2f'))
    colour_bar.set_label('fraction of average price paid for commute time')
    plt.title('Price relative to commute')
    figure = plt.gcf()
    figure.set_size_inches(36, 36)
    figure.savefig(os.path.join(OUTPUT_PATH, 'relative_price.png'), bbox_inches='tight')
def plot(outfn, a, genomeSize, base2chr, _windowSize, dpi=300, ext="svg"):
    """Save contact plot.

    Draws ``a + 1`` (the +1 keeps zero counts valid for the log colour
    scale) as a heat map spanning ``[0, genomeSize]`` on both axes and saves
    it to ``"<outfn>.<ext>"``.

    Parameters
    ----------
    outfn : str
        Output path without extension.
    a : 2-D array
        Contact matrix.
    genomeSize : int
        Genome length; sets the image extent.
    base2chr : dict
        Genome position -> chromosome name.  With fewer than 50 entries the
        Y axis is labelled with these names.
    _windowSize : int
        Window size, shown in the title divided by 1000.
    dpi, ext
        Resolution and file extension for the saved figure.
    """
    def format_fn(tick_val, tick_pos):
        """Mark axis ticks with chromosome names"""
        if int(tick_val) in base2chr:
            return base2chr[int(tick_val)]
        sys.stderr.write("[WARNING] %s not in ticks!\n" % tick_val)
        return ''
    # invert base2chr: positions are mirrored to genomeSize - b (presumably
    # because imshow draws row 0 at the top of the extent).  format_fn reads
    # the rebound dict through its closure.  .items() works on Python 2 and
    # 3; .iteritems() was Python-2-only.
    base2chr = {genomeSize - b: c for b, c in base2chr.items()}
    # start figure
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title("Contact intensity plot [%sk]" % (_windowSize / 1000,))
    # label Y axis with chromosome names
    if len(base2chr) < 50:
        ax.yaxis.set_major_formatter(FuncFormatter(format_fn))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Pass a real list; a Python 3 dict_keys view is not a sequence.
        plt.yticks(list(base2chr.keys()))
        ax.set_ylabel("Chromosomes")
    else:
        ax.set_ylabel("Genome position")
    # label axes
    ax.set_xlabel("Genome position")
    plt.imshow(a + 1, cmap=cm.hot, norm=LogNorm(), extent=(0, genomeSize, 0, genomeSize))
    plt.colorbar()
    # save
    fig.savefig("%s.%s" % (outfn, ext), dpi=dpi, papertype="a4")
def plot(outfn, a, genomeSize, base2chr, _windowSize, dpi=300, ext="svg"):
    """Save contact plot.

    Draws ``a + 1`` (the +1 keeps zero counts valid for the log colour
    scale) as a heat map spanning ``[0, genomeSize]`` on both axes and saves
    it to ``"<outfn>.<ext>"``.

    Parameters
    ----------
    outfn : str
        Output path without extension.
    a : 2-D array
        Contact matrix.
    genomeSize : int
        Genome length; sets the image extent.
    base2chr : dict
        Genome position -> chromosome name.  With fewer than 50 entries the
        Y axis is labelled with these names.
    _windowSize : int
        Window size, shown in the title divided by 1000.
    dpi, ext
        Resolution and file extension for the saved figure.
    """
    def format_fn(tick_val, tick_pos):
        """Mark axis ticks with chromosome names"""
        if int(tick_val) in base2chr:
            return base2chr[int(tick_val)]
        sys.stderr.write("[WARNING] %s not in ticks!\n" % tick_val)
        return ''
    # invert base2chr: positions are mirrored to genomeSize - b (presumably
    # because imshow draws row 0 at the top of the extent).  format_fn reads
    # the rebound dict through its closure.  .items() works on Python 2 and
    # 3; .iteritems() was Python-2-only.
    base2chr = {genomeSize - b: c for b, c in base2chr.items()}
    # start figure
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title("Contact intensity plot [%sk]" % (_windowSize / 1000,))
    # label Y axis with chromosome names
    if len(base2chr) < 50:
        ax.yaxis.set_major_formatter(FuncFormatter(format_fn))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Pass a real list; a Python 3 dict_keys view is not a sequence.
        plt.yticks(list(base2chr.keys()))
        ax.set_ylabel("Chromosomes")
    else:
        ax.set_ylabel("Genome position")
    # label axes
    ax.set_xlabel("Genome position")
    plt.imshow(a + 1, cmap=cm.hot, norm=LogNorm(), extent=(0, genomeSize, 0, genomeSize))
    plt.colorbar()
    # save
    fig.savefig("%s.%s" % (outfn, ext), dpi=dpi, papertype="a4")
def plot_confusion_matrix(labels, confusion_matrix, directory, name, extension):
    """
    Plot the confusion matrix on a log colour scale with the label names as
    axis ticks, and save it to ``<directory>/<name>.<extension>``.

    The y-axis ticks also show each class's IoU (from ``calculate_iou``),
    and the title shows the mean IoU over classes whose IoU is not NaN.

    Parameters
    ----------
    labels : sequence of dict
        One entry per class, each with a ``'name'`` key.
    confusion_matrix : 2-D array
        Passed to ``calculate_iou`` and drawn with ``imshow``.
    directory, name, extension : str
        Output location.
    """
    ious = calculate_iou(confusion_matrix)
    size = len(labels) / 5 + 2
    fig, ax = plt.subplots(figsize=(size + 2, size))
    image = ax.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Blues, norm=LogNorm())
    ticks_with_iou = []
    ticks_without_iou = []
    ious_for_average = []
    for label, iou in zip(labels, ious):
        if math.isnan(iou):
            # Shown as 0.00% on the axis but left out of the average.
            iou = 0
        else:
            ious_for_average.append(iou)
        ticks_with_iou.append("{}: {:.2%}".format(label['name'], iou))
        ticks_without_iou.append("{}".format(label['name']))
    avg_iou = np.average(ious_for_average)
    tick_marks = np.arange(len(ious))
    fig.colorbar(image)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(ticks_without_iou, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(ticks_with_iou)
    ax.set_title("Average IoU: {:.2%}".format(avg_iou))
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
    fig.tight_layout()
    fig.savefig(os.path.join(directory, '{}.{}'.format(name, extension)))
    # pyplot keeps every figure alive until it is closed; release it so
    # repeated calls do not leak memory.
    plt.close(fig)
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=None,
                          zmin=1):
    """Print and plot the confusion matrix for the intent classification.
    Normalization can be applied by setting `normalize=True`.

    The image is drawn from the raw *cm* on a log colour scale between
    *zmin* and ``cm.max()``.  Normalisation (dividing each row by its sum)
    happens afterwards, so it only affects the logged matrix and the
    per-cell text annotations.  Rows that sum to zero are left as zeros.
    """
    import numpy as np
    zmax = cm.max()
    plt.clf()
    plt.imshow(cm, interpolation='nearest', cmap=cmap if cmap else plt.cm.Blues,
               aspect='auto', norm=LogNorm(vmin=zmin, vmax=zmax))
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)
    if normalize:
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Classes with no true samples would divide by zero (NaN plus a
        # RuntimeWarning); keep those rows at 0 instead.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        logger.info("Normalized confusion matrix: \n{}".format(cm))
    else:
        logger.info("Confusion matrix, without normalization: \n{}".format(cm))
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def create_plot_correlation(params,plabs,col='red',mark='.',num=[]):
    """Plot pairwise 2-D histograms (log colour scale) of parameter samples.

    Builds a len(n) x len(n) GridSpec and, for each pair with j < i, draws
    ``plt.hist2d(params[j], params[i])`` in row o / column p.  The result is
    saved as ``outdir + '/' + star + '_correlations.pdf'``; ``outdir`` and
    ``star`` are module-level names not visible here.

    Parameters
    ----------
    params : sequence
        One sample array per parameter, indexed ``params[i]``.
    plabs : sequence of str
        Axis labels, indexed like ``params``.
    col, mark :
        Accepted but not used in this body.
    num : list
        Parameter indices to include; empty means all.  The mutable default
        is only read, never modified.
    """
    # NOTE(review): this block's indentation was reconstructed from a
    # flattened copy; the nesting below (labelling and hist2d only for
    # j < i, p advancing once per column) is the most plausible reading —
    # confirm against the original source.
    # NOTE(review): labelleft='off' / labelbottom='off' use the legacy string
    # form; newer Matplotlib expects booleans (False) — confirm the tick
    # labels are actually hidden on the version in use.
    if ( len(num) < 1 ):
        n = range(0,len(params))
    else:
        n = num
    # Figure 1 is reused; each cell gets 4x4 inches.
    plt.figure(1,figsize=(4*len(n),4*len(n)))
    gs = gridspec.GridSpec(nrows=len(n),ncols=len(n))
    o = 0  # row position within n
    for i in n:
        p = 0  # column position within n
        for j in n:
            if ( j < i ):
                plt.subplot(gs[o*len(n)+p])
                plt.tick_params( axis='y',which='both',direction='in',labelleft='off')
                plt.tick_params( axis='x',which='both',direction='in',labelbottom='off')
                plt.ticklabel_format(useOffset=False, axis='both')
                # Every branch below requests hidden tick labels; only the
                # first column gets a y-axis label.
                if ( j == n[0] ):
                    plt.ylabel(plabs[i],fontsize=25)
                elif ( j == i - 1 ):
                    plt.tick_params( axis='y',which='both',direction='in',labelleft='off')
                    plt.tick_params( axis='x',which='both',direction='in',labelbottom='off')
                else:
                    plt.tick_params( axis='y',which='both',direction='in',labelleft='off')
                    plt.tick_params( axis='x',which='both',direction='in',labelbottom='off')
                # Only the bottom row gets an x-axis label.
                if ( i == n[len(n)-1]):
                    plt.xlabel(plabs[j],fontsize=25)
                else:
                    plt.tick_params( axis='y',which='both',direction='in',labelleft='off')
                    plt.tick_params( axis='x',which='both',direction='in',labelbottom='off')
                plt.hist2d(params[j],params[i],bins=100,norm=LogNorm())
            p = p + 1
        o = o + 1
    fname = outdir+'/'+star+'_correlations.pdf'
    print 'Creating ', fname
    plt.savefig(fname,format='pdf',bbox_inches='tight')
    plt.close()
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=None,
                          zmin=1):
    """Print and plot the confusion matrix for the intent classification.
    Normalization can be applied by setting `normalize=True`.

    The image is drawn from the raw *cm* on a log colour scale between
    *zmin* and ``cm.max()``.  Normalisation (dividing each row by its sum)
    happens afterwards, so it only affects the logged matrix and the
    per-cell text annotations.  Rows that sum to zero are left as zeros.
    """
    import numpy as np
    zmax = cm.max()
    plt.clf()
    plt.imshow(cm, interpolation='nearest', cmap=cmap if cmap else plt.cm.Blues,
               aspect='auto', norm=LogNorm(vmin=zmin, vmax=zmax))
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)
    if normalize:
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Classes with no true samples would divide by zero (NaN plus a
        # RuntimeWarning); keep those rows at 0 instead.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        logger.info("Normalized confusion matrix: \n{}".format(cm))
    else:
        logger.info("Confusion matrix, without normalization: \n{}".format(cm))
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def _save_2d_error_plot(self, detector, xlist, ylist, elist, x_axis_label, y_axis_label, z_axis_label):
    """Save a colour-mesh plot of *elist* over (*xlist*, *ylist*).

    The file is ``<self.plot_filename without extension>_error.png``.  Axes
    listed in ``self.axis_with_logscale`` use a symlog scale (x/y) or a log
    colour scale (z).  Values are clipped at 0 before drawing.  *detector*
    is not used by this body.

    If a log z scale is requested but *elist* has no positive entries, a
    linear colour scale is used instead of crashing on an empty minimum.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import colors
    # configure logscale on X and Y axis (both for positive and negative numbers)
    fig, ax = plt.subplots(1, 1)
    if PlotAxis.x in self.axis_with_logscale:
        plt.xscale('symlog')
    if PlotAxis.y in self.axis_with_logscale:
        plt.yscale('symlog')
    positive = elist[elist > 0]
    if PlotAxis.z in self.axis_with_logscale and positive.size > 0:
        norm = colors.LogNorm(vmin=positive.min(), vmax=elist.max())
    else:
        norm = colors.Normalize(vmin=elist.min(), vmax=elist.max())
    plt.xlabel(x_axis_label)
    plt.ylabel(y_axis_label)
    mesh = plt.pcolormesh(xlist, ylist, elist.clip(0.0), cmap=self.colormap, norm=norm)
    cbar = fig.colorbar(mesh)
    cbar.set_label(label=z_axis_label, rotation=270, verticalalignment='bottom')
    base_name, _ = os.path.splitext(self.plot_filename)
    plt.savefig(base_name + "_error.png")
    plt.close()
def clicks_heatmap_first_occ():
    """Plot first-occurrence link clicks as a heat map over the normalised page.

    Each link's position is normalised by a 1920-px page width and by its
    ``page_length``.  Points beyond 1.0 on either axis are dropped.  Points
    are weighted by their ``counts`` and binned 100x100.  Two PDFs are
    written to ``output/``: one log-normalised, one linear.
    """
    print('loading')
    db = MySQLDatabase(DATABASE_HOST, DATABASE_USER, DATABASE_PASSWORD, DATABASE_NAME)
    db_worker_view = db.get_work_view()
    coords = db_worker_view.retrieve_all_links_coords_clicks_first_occ()
    print('coord loaded')
    x = []
    y = []
    values = []
    for link in coords.values():
        x_normed = float(link['x']) / float(1920)
        y_normed = float(link['y']) / float(link['page_length'])
        if x_normed <= 1.0 and y_normed <= 1.0:
            x.append(x_normed)
            y.append(y_normed)
            values.append(float(link['counts']))
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=100, weights=values)
    # histogram2d indexes x along rows; imshow draws rows vertically, so
    # transpose to put x on the horizontal axis to match the extent.
    heatmap = heatmap.T
    extent = [xedges[0], xedges[-1], yedges[-1], yedges[0]]
    fig_size = (2.4, 2)
    outputs = (
        (LogNorm(), 'output/clicks_heatmap_lognormed_self_loop_first_occ.pdf'),
        (Normalize(), 'output/clicks_heatmap_normed_self_loop_first_occ.pdf'),
    )
    for norm, out_path in outputs:
        plt.clf()
        plt.figure(figsize=fig_size)
        plt.grid(True)
        plt.imshow(heatmap, extent=extent, origin='upper', norm=norm, cmap=plt.get_cmap('jet'))
        plt.colorbar()
        # Save before show(): interactive backends close the figure when
        # show() returns, which made savefig write a blank page.
        plt.savefig(out_path)
        plt.show()
    print("done")
def clicks_heatmap_total():
    """Plot all link clicks as a heat map over the normalised page.

    Each click position is normalised by a 1920-px page width and by its
    ``page_length``.  Points beyond 1.0 on either axis are dropped.  Points
    are weighted by their ``counts`` and binned 100x100.  Two PDFs are
    written to ``output/``: one log-normalised, one linear.
    """
    print('loading')
    db = MySQLDatabase(DATABASE_HOST, DATABASE_USER, DATABASE_PASSWORD, DATABASE_NAME)
    db_worker_view = db.get_work_view()
    coords = db_worker_view.retrieve_all_links_coords_clicks()
    print('coord loaded')
    x = []
    y = []
    values = []
    for coord in coords:
        x_normed = float(coord['x']) / float(1920)
        y_normed = float(coord['y']) / float(coord['page_length'])
        if x_normed <= 1.0 and y_normed <= 1.0:
            x.append(x_normed)
            y.append(y_normed)
            values.append(float(coord['counts']))
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=100, weights=values)
    # histogram2d indexes x along rows; imshow draws rows vertically, so
    # transpose to put x on the horizontal axis to match the extent.
    heatmap = heatmap.T
    extent = [xedges[0], xedges[-1], yedges[-1], yedges[0]]
    fig_size = (2.4, 2)
    outputs = (
        (LogNorm(), 'output/clicks_heatmap_lognormed_self_loop_total.pdf'),
        (Normalize(), 'output/clicks_heatmap_normed_self_loop_total.pdf'),
    )
    for norm, out_path in outputs:
        plt.clf()
        plt.figure(figsize=fig_size)
        plt.grid(True)
        plt.imshow(heatmap, extent=extent, origin='upper', norm=norm, cmap=plt.get_cmap('jet'))
        plt.colorbar()
        # Save before show(): interactive backends close the figure when
        # show() returns, which made savefig write a blank page.
        plt.savefig(out_path)
        plt.show()
    print("done")
def links_heatmap():
    """Plot link positions as a heat map over the normalised page.

    Each link position is normalised by a 1920-px page width and by the
    length of its source article.  Points beyond 1.0 on either axis are
    dropped.  Points are binned 100x100 (unweighted).  Two PDFs are written
    to ``output/``: one log-normalised, one linear.
    """
    # http://stackoverflow.com/questions/2369492/generate-a-heatmap-in-matplotlib-using-a-scatter-data-set
    print('loading')
    db = MySQLDatabase(DATABASE_HOST, DATABASE_USER, DATABASE_PASSWORD, DATABASE_NAME)
    db_worker_view = db.get_work_view()
    coords = db_worker_view.retrieve_all_links_coords()
    print('coord loaded')
    x = []
    y = []
    page_lengths = db_worker_view.retrieve_all_page_lengths()
    print('lenghts loaded')
    for coord in coords:
        x_normed = float(coord['x']) / float(1920)
        y_normed = float(coord['y']) / float(page_lengths[coord['source_article_id']])
        if x_normed <= 1.0 and y_normed <= 1.0:
            x.append(x_normed)
            y.append(y_normed)
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=100)
    # histogram2d indexes x along rows; imshow draws rows vertically, so
    # transpose to put x on the horizontal axis (as in the referenced answer).
    heatmap = heatmap.T
    extent = [xedges[0], xedges[-1], yedges[-1], yedges[0]]
    fig_size = (2.4, 2)
    outputs = (
        (LogNorm(), 'output/links_heatmap_lognormed_self_loop.pdf'),
        (Normalize(), 'output/links_heatmap_normed_self_loop.pdf'),
    )
    for norm, out_path in outputs:
        plt.clf()
        plt.figure(figsize=fig_size)
        plt.grid(True)
        plt.imshow(heatmap, extent=extent, origin='upper', norm=norm, cmap=plt.get_cmap('jet'))
        plt.colorbar()
        # Save before show(): interactive backends close the figure when
        # show() returns, which made savefig write a blank page.
        plt.savefig(out_path)
        plt.show()
    print("done")
def plot(self, keys=None, burn=1000):
    """Draw a corner plot of the chains in ``self.chain``.

    Diagonal panels show a normalised step histogram of each chain after
    dropping the first *burn* samples, its mean (green line) and
    ``stat_text`` summary.  Lower-triangle panels show log-scaled 2-D
    histograms of each pair with both means marked.  Upper-triangle panels
    are left empty.  Axis labels come from ``self.descr[key][3]``.

    Parameters
    ----------
    keys : list, optional
        Chain names to plot; defaults to ``self.names0``.
    burn : int
        Number of leading samples discarded from every chain.
    """
    if keys is None:
        keys = self.names0
    nkeys = len(keys)
    k = 0
    for i in range(nkeys):
        for j in range(nkeys):
            k = k + 1
            if i == j:
                x = self.chain[keys[i]][burn:]
                plt.subplot(nkeys, nkeys, k)
                # Bin counts must be integers; ``x.size/1000`` is a float on
                # Python 3 and made histogram() raise.
                nbins = max(20, x.size // 1000)
                # ``normed`` was removed from Matplotlib; ``density`` replaces it.
                plt.hist(x, bins=nbins, density=True, histtype='step')
                plt.axvline(np.mean(x), lw=2.0, color='g')
                if i == (nkeys - 1):
                    plt.xlabel(self.descr[keys[i]][3])
                plt.text(0.05, 0.7, stat_text(x), transform=plt.gca().transAxes)
                plt.gca().xaxis.set_major_locator(MaxNLocator(3, prune="both"))
                plt.gca().yaxis.set_major_locator(MaxNLocator(3, prune="both"))
                plt.gca().set_yticklabels([])
            elif i > j:
                plt.subplot(nkeys, nkeys, k)
                x = self.chain[keys[j]][burn:]
                y = self.chain[keys[i]][burn:]
                nbins = max(32, x.size // 1000)
                plt.hist2d(x, y, bins=[nbins, nbins], norm=LogNorm())
                plt.axvline(np.mean(x), lw=2.0)
                plt.axhline(np.mean(y), lw=2.0)
                # Only the left column keeps y tick labels, only the bottom row x.
                if j == 0:
                    plt.ylabel(self.descr[keys[i]][3])
                else:
                    plt.gca().set_yticklabels([])
                if i == (nkeys - 1):
                    plt.xlabel(self.descr[keys[j]][3])
                else:
                    plt.gca().set_xticklabels([])
                plt.gca().xaxis.set_major_locator(MaxNLocator(3, prune="both"))
                plt.gca().yaxis.set_major_locator(MaxNLocator(3, prune="both"))
    plt.subplots_adjust(hspace=0.15, wspace=0.1)
def plotFace2D(
    mesh2D,
    j, real_or_imag='real', ax=None, range_x=None,
    range_y=None, sample_grid=None,
    logScale=True, clim=None, mirror=False, pcolorOpts=None,
    cbar=True
):
    """
    Create a streamplot (a slice in the theta direction) of a face vector

    :param discretize.CylMesh mesh2D: cylindrically symmetric mesh
    :param np.ndarray j: face vector (x, z components)
    :param str real_or_imag: real or imaginary component
    :param matplotlib.axes ax: axes (a new 6x4 figure is created if None)
    :param numpy.ndarray range_x: x-extent over which we want to plot
    :param numpy.ndarray range_y: y-extent over which we want to plot
    :param numpy.ndarray sample_grid: x, y spacings at which to re-sample the plotting grid
    :param bool logScale: use a log scale for the colorbar?
    :param clim: colour limits, applied to the colour bar (or to the plotted
        mappable when ``cbar`` is False)
    :param bool mirror: forwarded to ``mesh2D.plotImage``
    :param dict pcolorOpts: extra pcolor options; copied, never modified
    :param bool cbar: add a colour bar?
    :return: ``(ax,)``, or ``(ax, colorbar)`` when ``cbar`` is True
    :raises ValueError: if ``len(j)`` is neither ``mesh2D.nF`` nor ``2 * mesh2D.nC``
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    if len(j) == mesh2D.nF:
        vType = 'F'
    elif len(j) == mesh2D.nC * 2:
        vType = 'CCv'
    else:
        # Previously fell through and failed later with an unbound vType.
        raise ValueError(
            'j must have length mesh2D.nF or 2 * mesh2D.nC, got {}'.format(len(j))
        )
    # Copy so the caller's dict is never mutated, and keep the caller's
    # options when logScale is False (they used to be discarded).
    pcolorOpts = {} if pcolorOpts is None else dict(pcolorOpts)
    if logScale is True:
        pcolorOpts['norm'] = LogNorm()
    f = mesh2D.plotImage(
        getattr(j, real_or_imag),
        view='vec', vType=vType, ax=ax,
        range_x=range_x, range_y=range_y, sample_grid=sample_grid,
        mirror=mirror,
        pcolorOpts=pcolorOpts,
    )
    out = (ax,)
    if cbar is True:
        cb = plt.colorbar(f[0], ax=ax)
        # Return the colour bar object (the boolean flag was returned before).
        out += (cb,)
        if clim is not None:
            cb.set_clim(clim)
            cb.update_ticks()
    elif clim is not None:
        # No colour bar exists (``cb`` used to be undefined here): apply the
        # limits to the plotted mappable directly.
        f[0].set_clim(clim)
    return out
def plotEdge2D(
    mesh2D,
    h, real_or_imag='real', ax=None, range_x=None,
    range_y=None, sample_grid=None,
    logScale=True, clim=None, mirror=False, pcolorOpts=None
):
    """
    Create a pcolor plot (a slice in the theta direction) of an edge vector

    :param discretize.CylMesh mesh2D: cylindrically symmetric mesh
    :param np.ndarray h: edge vector (y components)
    :param str real_or_imag: real or imaginary component
    :param matplotlib.axes ax: axes (a new 6x4 figure is created if None)
    :param numpy.ndarray range_x: x-extent over which we want to plot
    :param numpy.ndarray range_y: y-extent over which we want to plot
    :param numpy.ndarray sample_grid: x, y spacings at which to re-sample the plotting grid
    :param bool logScale: use a log scale for the colorbar?
    :param clim: colour limits applied to the colour bar
    :param bool mirror: forwarded to ``mesh2D.plotImage``
    :param dict pcolorOpts: extra pcolor options; copied, never modified
    :return: ``(ax, colorbar)``
    :raises ValueError: if ``len(h)`` matches none of ``nE``, ``nC``, ``2 * nC``
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    if len(h) == mesh2D.nE:
        vType = 'E'
    elif len(h) == mesh2D.nC:
        vType = 'CC'
    elif len(h) == 2 * mesh2D.nC:
        vType = 'CCv'
    else:
        # Previously fell through and failed later with an unbound vType.
        raise ValueError(
            'h must have length mesh2D.nE, mesh2D.nC or 2 * mesh2D.nC, got {}'.format(len(h))
        )
    # The default pcolorOpts=None used to crash on item assignment when
    # logScale was True; copy so the caller's dict is never mutated and
    # keep the caller's options when logScale is False.
    pcolorOpts = {} if pcolorOpts is None else dict(pcolorOpts)
    if logScale is True:
        pcolorOpts['norm'] = LogNorm()
    image = mesh2D.plotImage(
        getattr(h, real_or_imag),
        view='real', vType=vType, ax=ax,
        range_x=range_x, range_y=range_y, sample_grid=sample_grid,
        mirror=mirror,
        pcolorOpts=pcolorOpts,
    )[0]
    cb = plt.colorbar(image, ax=ax)
    if clim is not None:
        cb.set_clim(clim)
    return ax, cb
def save_png_series(imgs, ROI=None, logs=True, outDir=None, uid=None, vmin=None, vmax=None, cmap='viridis', dpi=100):
    """
    save a series of images in a format of png

    One file per image is written as
    ``outDir + 'uid_<uid>-frame-<index>.png'``.

    Parameters
    ----------
    imgs : array
        image data array for the movie
        dimensions are: [num_img][num_rows][num_cols]
    ROI: e.g. xs,xe,ys,ye = vert #x_start, x_end, y_start,y_end
        when given, each image is cropped with ``select_regoin``
    logs: if True, use a log colour scale (LogNorm(vmin, vmax))
    outDir: the output path prefix; the file name is appended directly, so
        include a trailing separator. Must be a string.
    uid: identifier used in the file names ('uid' if None)
    vmin/vmax: for image contrast
    cmap: the color for plot
    dpi: resolution of the saved files

    Returns
    -------
    save png files
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm
    if uid is None:
        uid = 'uid'
    for num_frame, img in enumerate(imgs):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if ROI is None:
            i0 = img
        else:
            i0 = select_regoin(img, ROI, keep_shape=False)
            ax.set_aspect('equal')
        if not logs:
            ax.imshow(i0, origin='lower', cmap=cmap, interpolation="nearest", vmin=vmin, vmax=vmax)
        else:
            ax.imshow(i0, origin='lower', cmap=cmap,
                      interpolation="nearest", norm=LogNorm(vmin, vmax))
        fname = outDir + 'uid_%s-frame-%s.png' % (uid, num_frame)
        # Honour the dpi argument (it was previously ignored via dpi=None).
        fig.savefig(fname, dpi=dpi)
        # Close each frame; otherwise pyplot keeps every figure alive and
        # memory grows with the length of the series.
        plt.close(fig)
def spectrogram(self,
                ax=None,
                freq_range=None,
                dB_thresh=35,
                derivative=True,
                colormap='gray',
                compensated=True):
    """Plots a spectrogram, requires matplotlib

    ax - axis on which to plot (defaults to the current pyplot axes)
    freq_range - a tuple of frequencies, eg (300, 8000)
    dB_thresh - noise floor threshold in dB below the maximum; raise it to
                suppress noise, lower it to reveal more detail
    derivative - if True, plots the spectral derivative (SAP style) on a
                 symmetric log scale; otherwise plots power on a log scale
    colormap - colormap to use, good values: 'inferno', 'gray'
    compensated - if True, shifts the time axis so each column is centred on
                  its short FFT; if False, each column starts at the
                  beginning of its data window.  Both are equivalent when
                  n_overlap = 0 and the data window spans the full NFFT.

    Returns the axis object that was drawn on.
    """
    time_offset = 0
    if compensated:
        overlap = self._noverlap + self._data_in_window - self._NFFT
        if overlap < 0:
            print('warning: spectrogram does not fully cover the data')
            overlap = 0
        time_offset = (overlap / 2) / self._rate
    from matplotlib import colors
    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()
    if derivative:
        pxx, f, t = self.max_spec_derivative(freq_range=freq_range)
    else:
        pxx, f, t = self.power(freq_range)
    floor = value_from_dB(dB_thresh, np.max(pxx))
    if derivative:
        scale = colors.SymLogNorm(linthresh=floor)
    else:
        scale = colors.LogNorm(vmin=floor)
    ax.pcolorfast(t + time_offset, f, pxx, cmap=colormap, norm=scale)
    return ax
def plot_img(ax, data, **kwargs):
    """plot an image with ``ax.pcolorfast`` and strip the axes decoration.

    Keyword options (all optional):
      vmin, vmax, cmap -- forwarded to ``pcolorfast`` (default None)
      title            -- axes title; ``'matrix'`` when absent
      aspect           -- passed to ``ax.set_aspect``
      colorbar         -- if truthy, add a horizontal pyplot colour bar

    Axis labels and ticks are always cleared.  Previously ``title``,
    ``vmin``, ``vmax`` and ``cmap`` were read unconditionally (KeyError when
    missing, making the default title unreachable), and ``dict.has_key``
    only existed on Python 2.
    """
    assert ax is not None, "missing axis argument 'ax'"
    # FIXME: convert plottype into func: imshow, pcolor, pcolormesh, pcolorfast
    mpl = ax.pcolorfast(data, vmin=kwargs.get('vmin'), vmax=kwargs.get('vmax'), cmap=kwargs.get('cmap'))
    ax.grid(0)
    if 'aspect' in kwargs:
        ax.set_aspect(kwargs['aspect'])
    if kwargs.get('colorbar'):
        plt.colorbar(mappable=mpl, ax=ax, orientation="horizontal")
    ax.set_title(kwargs.get('title', 'matrix'))  # , fontsize=8)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_xticks([])
    ax.set_yticks([])
def plot_all_chan_spectrum(spectrum, bins, *, ax=None, **kwargs):
    """Interactive view of an energy-by-channel spectrum.

    The centre panel shows *spectrum* on a log colour scale, with channels
    on x and energy (from *bins*) on y.  The right panel shows the counts
    summed over channels for each energy bin.  The top panel shows counts
    per channel summed over an energy range; dragging a vertical span on
    the right panel selects that range.

    Parameters
    ----------
    spectrum : 2-D numpy array
        Rows follow the energy bins and columns the channels (inferred from
        the imshow extent — confirm against the caller).
    bins : 1-D numpy array
        Energy bin edges (labelled in keV).
    ax : Axes, optional
        Centre axes; a new 13.5x9.5 figure is created when omitted.
    **kwargs
        Accepted but not used by this body.

    Returns
    -------
    (spectrum, bins, artists) where *artists* maps 'center', 'top' and
    'right' to dicts of the axes and artists drawn there.  The SpanSelector
    is included so the caller can keep it alive.
    """
    def integrate_to_angles(spectrum, bins, lo, hi):
        """Sum the rows whose bin-edge indices fall between lo and hi."""
        lo_ind, hi_ind = bins.searchsorted([lo, hi])
        return spectrum[lo_ind:hi_ind].sum(axis=0)

    if ax is None:
        fig, ax = plt.subplots(figsize=(13.5, 9.5))
    else:
        fig = ax.figure
    # Derive the channel range from the data instead of assuming 384 channels.
    chan_lo, chan_hi = -.5, spectrum.shape[1] - .5
    div = make_axes_locatable(ax)
    ax_r = div.append_axes('right', 2, pad=0.1, sharey=ax)
    ax_t = div.append_axes('top', 2, pad=0.1, sharex=ax)
    ax_r.yaxis.tick_right()
    ax_r.yaxis.set_label_position("right")
    ax_t.xaxis.tick_top()
    ax_t.xaxis.set_label_position("top")
    im = ax.imshow(spectrum, origin='lower', aspect='auto',
                   extent=(chan_lo, chan_hi, bins[0], bins[-1]),
                   norm=LogNorm())
    # Plot each energy total at its bin centre; the old
    # ``bins[:-1] + np.diff(bins)`` landed on the right bin edges.
    bin_centres = bins[:-1] + np.diff(bins) / 2
    e_line, = ax_r.plot(spectrum.sum(axis=1), bin_centres)
    p_line, = ax_t.plot(spectrum.sum(axis=0))
    # The top panel starts out summed over the full energy range; label it
    # with that range in the same format update() uses (was a hard-coded
    # and misspelt '[0, 70] kEv').
    label = ax_t.annotate(f'[{bins[0]:.1f}, {bins[-1]:.1f}] keV', (0, 1), (10, -10),
                          xycoords='axes fraction',
                          textcoords='offset pixels',
                          va='top', ha='left')

    def update(lo, hi):
        """SpanSelector callback: re-integrate the top panel over [lo, hi]."""
        p_data = integrate_to_angles(spectrum, bins, lo, hi)
        p_line.set_ydata(p_data)
        ax_t.relim()
        ax_t.autoscale(axis='y')
        label.set_text(f'[{lo:.1f}, {hi:.1f}] keV')
        fig.canvas.draw_idle()

    span = SpanSelector(ax_r, update, 'vertical', useblit=True,
                        rectprops={'alpha': .5, 'facecolor': 'red'},
                        span_stays=True)
    ax.set_xlabel('channel [#]')
    ax.set_ylabel('E [keV]')
    ax_t.set_xlabel('channel [#]')
    ax_t.set_ylabel('total counts')
    ax_r.set_ylabel('E [keV]')
    ax_r.set_xlabel('total counts')
    ax.set_xlim(chan_lo, chan_hi)
    ax.set_ylim(bins[0], bins[-1])
    ax_r.set_xlim(xmin=0)
    return spectrum, bins, {'center': {'ax': ax, 'im': im},
                            'top': {'ax': ax_t, 'p_line': p_line},
                            'right': {'ax': ax_r, 'e_line': e_line,
                                      'span': span}}