def hrd_key(self, key_str):
    """
    Plot an HR diagram (log L against log Teff) with a legend label.

    After plotting, if the current x-limits are increasing,
    ``self._xlimrev()`` is called (presumably to reverse the x-axis into the
    usual HR-diagram orientation).

    Parameters
    ----------
    key_str : string
        A label string for the legend entry.
    """
    log_teff = self.data[:, self.cols['log_Teff'] - 1]
    log_l = self.data[:, self.cols['log_L'] - 1]
    pyl.plot(log_teff, log_l, label=key_str)
    pyl.legend()
    pyl.xlabel('log Teff')
    pyl.ylabel('log L')
    # Query the limits through the same ``pyl`` alias used for plotting.
    x1, x2 = pyl.xlim()
    if x2 > x1:
        self._xlimrev()
python类plot()的实例源码
def compare_images(path='.', S_limit=10.):
    """
    Compare every ``Abu*`` image in *path* with its ``MasterAbu*`` counterpart.

    Both file lists are sorted and paired position by position. For each pair
    the first value returned by ``compare_entropy`` is compared with
    *S_limit*, and a warning is issued for every pair above the limit. If any
    pair exceeds the limit, or if the number of ``Abu*`` and ``MasterAbu*``
    files differs, an error message is printed and the process exits with
    status 1.

    Parameters
    ----------
    path : str, optional
        Directory containing the images. Default '.'.
    S_limit : float, optional
        Threshold above which a pair is reported. Default 10.

    Returns
    -------
    list of float
        The score of every compared pair, in sorted file order. This is only
        returned when no error was found; otherwise ``sys.exit(1)`` is called.
    """
    file_list = sorted(glob.glob(os.path.join(path, 'Abu*')))
    file_list_master = sorted(glob.glob(os.path.join(path, 'MasterAbu*')))
    S = []
    print("Identifying images with rmq > " + '%3.1f' % S_limit)
    ierr_count = 0
    # A missing image or master must not pass unnoticed. Indexing the master
    # list by position would crash when it is shorter and would silently
    # ignore extra masters when it is longer.
    if len(file_list) != len(file_list_master):
        warnings.warn("found %d Abu* images but %d MasterAbu* images"
                      % (len(file_list), len(file_list_master)))
        ierr_count += 1
    for image, master in zip(file_list, file_list_master):
        this_S, _, _ = compare_entropy(image, master)
        if this_S > S_limit:
            warnings.warn(image + " and " + master + " differ by " + '%6.3f' % this_S)
            ierr_count += 1
        S.append(this_S)
    if ierr_count > 0:
        print("Error: at least one image differs by more than S_limit")
        sys.exit(1)
    return S
def plot_prof_2(self, mod, species, xlim1, xlim2):
    """
    Plot the 'yps' profile of one species against mass for one cycle.

    Parameters
    ----------
    mod : string or integer
        Model to plot, same as cycle number.
    species : string
        Which species to plot; it is passed to
        ``self.se.get(mod, 'yps', species)``.
    xlim1, xlim2 : float
        Mass coordinate range.
    """
    mass = self.se.get(mod, 'mass')
    Xspecies = self.se.get(mod, 'yps', species)
    # Use str() so that a non-string species (e.g. a list) does not raise
    # TypeError while the legend label is built.
    pyl.plot(mass, Xspecies, '-', label=str(mod) + ', ' + str(species))
    pyl.xlim(xlim1, xlim2)
    pyl.legend()
def plot(traj, x, y, **kwargs):
    """ Create a matplotlib plot of property x against property y

    Args:
        traj: object whose attributes are looked up when x or y is a string
        x,y (str or array-like): names of properties of ``traj``, or the
            values themselves. When a name is given, the axis is labelled
            ``"<name> / <units>"``, using the value's ``units`` attribute
            (``'dimensionless'`` if it has none). When values are given,
            the axis label is None.
        **kwargs (dict): kwargs for :meth:`matplotlib.pylab.plot`

    Returns:
        List[matplotlib.lines.Lines2D]: the lines that were plotted
    """
    from matplotlib import pylab

    def _resolve(value):
        # isinstance (not ``type(...) is str``) so str subclasses count as names
        if isinstance(value, str):
            data = getattr(traj, value)
            return data, '%s / %s' % (value, getattr(data, 'units', 'dimensionless'))
        return value, None

    x, xl = _resolve(x)
    y, yl = _resolve(y)
    lines = pylab.plot(x, y, **kwargs)
    pylab.xlabel(xl)
    pylab.ylabel(yl)
    pylab.grid()
    return lines
def plot(m, Xtrain, ytrain):
    """Plot a fitted sparse-GP regression model on 1-D data in a new figure.

    Draws:
    - the training points as black crosses;
    - the predictive mean from ``m.predict_y`` on 100 grid points over
      [-0.5, 1.5], with a shaded band of +/- 2 standard deviations;
    - ``m.predict_f`` at the model's ``zu`` points (presumably the inducing
      inputs) as red error bars of +/- 2 standard deviations.

    The visible x-range is set to [-0.1, 1.1].
    """
    grid = np.linspace(-0.5, 1.5, 100)[:, None]
    n_grid = grid.shape[0]
    pred_mean, pred_var = m.predict_y(grid)
    pred_mean = np.reshape(pred_mean, (n_grid, 1))
    pred_var = np.reshape(pred_var, (n_grid, 1))

    # Where ``zu`` lives depends on the model class.
    if isinstance(m, aep.SDGPR):
        inducing = m.sgp_layers[0].zu
    elif isinstance(m, vfe.SGPR_collapsed):
        inducing = m.zu
    else:
        inducing = m.sgp_layer.zu
    f_mean, f_var = m.predict_f(inducing)

    plt.figure()
    plt.plot(Xtrain, ytrain, 'kx', mew=2)
    plt.plot(grid, pred_mean, 'b', lw=2)
    centre = pred_mean[:, 0]
    half_width = 2 * np.sqrt(pred_var[:, 0])
    plt.fill_between(grid[:, 0], centre - half_width, centre + half_width,
                     color='blue', alpha=0.2)
    plt.errorbar(inducing, f_mean, yerr=2 * np.sqrt(f_var), fmt='ro')
    plt.xlim(-0.1, 1.1)
def run_regression_1D_aep_two_layers():
    """Fit a two-layer AEP model on the 1-D dataset with L-BFGS-B and print metrics.

    The model is ``aep.SDGPR`` with M=20 and ``hidden_sizes=[2]``. Printed
    metrics:
    - ``train ml``: the objective value returned by ``objective_function``;
    - ``test rmse``: root mean squared error of the predictive mean;
    - ``ll``: mean Gaussian log predictive density on the test set.

    The original author reports output like::

        alpha=1.000, train ml=-51.385404, test rmse=0.168, ll=0.311
    """
    np.random.seed(42)
    print("create dataset ...")
    Xtrain, ytrain, Xtest, ytest = create_dataset()
    alpha = 1  # other alpha is not valid here
    M = 20
    model = aep.SDGPR(Xtrain, ytrain, M, hidden_sizes=[2])
    # Use the same alpha as the objective evaluation below, not a duplicated literal.
    model.optimise(method='L-BFGS-B', alpha=alpha, maxiter=5000, disp=False)
    my, vy = model.predict_y(Xtest)
    my = np.reshape(my, ytest.shape)
    vy = np.reshape(vy, ytest.shape)
    rmse = np.sqrt(np.mean((my - ytest)**2))
    ll = np.mean(-0.5 * np.log(2 * np.pi * vy) - 0.5 * (ytest - my)**2 / vy)
    nlml, _ = model.objective_function(model.get_hypers(), Xtrain.shape[0], alpha)
    # Single parenthesised argument: prints the same on Python 2 and 3.
    print('alpha=%.3f, train ml=%3f, test rmse=%.3f, ll=%.3f' % (alpha, nlml, rmse, ll))
def run_regression_1D_aep_two_layers_stoc():
    """Fit a two-layer AEP model on the 1-D dataset with Adam and print metrics.

    The model is ``aep.SDGPR`` with M=20 and ``hidden_sizes=[2]``, optimised
    with ``method='adam'``. Printed metrics:
    - ``train ml``: the objective value returned by ``objective_function``;
    - ``test rmse``: root mean squared error of the predictive mean;
    - ``ll``: mean Gaussian log predictive density on the test set.

    The original author reports output like::

        alpha=1.000, train ml=-69.444086, test rmse=0.170, ll=0.318
    """
    np.random.seed(42)
    print("create dataset ...")
    Xtrain, ytrain, Xtest, ytest = create_dataset()
    alpha = 1  # other alpha is not valid here
    M = 20
    model = aep.SDGPR(Xtrain, ytrain, M, hidden_sizes=[2])
    # Use the same alpha as the objective evaluation below, not a duplicated literal.
    model.optimise(method='adam', alpha=alpha, maxiter=5000, disp=False)
    my, vy = model.predict_y(Xtest)
    my = np.reshape(my, ytest.shape)
    vy = np.reshape(vy, ytest.shape)
    rmse = np.sqrt(np.mean((my - ytest)**2))
    ll = np.mean(-0.5 * np.log(2 * np.pi * vy) - 0.5 * (ytest - my)**2 / vy)
    nlml, _ = model.objective_function(model.get_hypers(), Xtrain.shape[0], alpha)
    # Single parenthesised argument: prints the same on Python 2 and 3.
    print('alpha=%.3f, train ml=%3f, test rmse=%.3f, ll=%.3f' % (alpha, nlml, rmse, ll))
def plot(param, show=1):
    """Plot the spectrum read by ``sed.SED().grmonty`` from an output file.

    Keyword arguments:
    param -- Output spectrum file
    show -- Optional, plot the spectrum on the screen. Enabled by default.

    When ``show == 1`` the figure is displayed and None is returned. For any
    other value, the list of lines created by ``pylab.plot`` (log nu against
    ``ll``) is returned without showing the figure.
    """
    spectrum = sed.SED()
    spectrum.grmonty(param)
    lines = pylab.plot(spectrum.lognu, spectrum.ll)
    if show != 1:
        return lines
    pylab.show()
def plot_position(self, pos_true, pos_est):
    """Plot estimated against ground-truth 2-D trajectories in a new figure.

    Both arrays are truncated to the number of columns of ``pos_est``.
    Row 0 is plotted on the x-axis (labelled East, m) and row 1 on the
    y-axis (labelled North, m), with equal axis scaling.

    Args:
        pos_true: ground-truth positions; at least 2 rows, one column per sample.
        pos_est: estimated positions, same layout.
    """
    N = pos_est.shape[1]
    pos_true = pos_true[:, :N]
    pos_est = pos_est[:, :N]

    plt.figure()
    plt.suptitle("Position")
    plt.plot(pos_true[0, :], pos_true[1, :],
             color="red", marker="o", label="Ground truth")
    plt.plot(pos_est[0, :], pos_est[1, :],
             color="blue", marker="o", label="Estimated")

    plt.xlabel("East (m)")
    plt.ylabel("North (m)")
    plt.axis("equal")
    plt.legend(loc=0)
def test_project(self):
    """Projecting the house point cloud with identity K and R and zero t
    yields a 3 x N array whose last row is all ones."""
    points = np.loadtxt(join(test.TEST_DATA_PATH, "house/house.p3d")).T

    # Identity K and R, zero translation.
    K = np.eye(3)
    R = np.eye(3)
    t = np.array([0, 0, 0])
    camera = PinholeCameraModel(320, 240, K)
    x = camera.project(points, R, t)

    self.assertEqual(x.shape, (3, points.shape[1]))
    self.assertTrue(np.all(x[2, :] == 1.0))

    # Flip to True to inspect the projection visually.
    debug = False
    if not debug:
        return
    plt.figure()
    plt.plot(x[0], x[1], 'k. ')
    plt.show()
def plot(self, track, track_cam_states, estimates):
    """Plot a track's ground-truth feature, camera positions and estimates.

    Each point is transformed with the module-level ``T_global_camera``
    before plotting (x = component 0, y = component 1):
    - the feature ``track.ground_truth`` in red;
    - every camera position ``cam_state.p_G`` in blue;
    - each estimate, offset by its camera's transformed position, in green.

    Estimates are paired with camera states by position. If one sequence is
    longer than the other, the extra items are ignored. The figure is then
    shown.
    """
    plt.figure()

    feature = T_global_camera * track.ground_truth
    plt.plot(feature[0], feature[1],
             marker="o", color="red", label="feature")

    # Label only the first camera marker so the legend has a single "camera"
    # entry; matplotlib skips labels that start with an underscore.
    for idx, cam_state in enumerate(track_cam_states):
        pos = T_global_camera * cam_state.p_G
        plt.plot(pos[0], pos[1], marker="o", color="blue",
                 label="camera" if idx == 0 else "_nolegend_")

    for cam_state, est in zip(track_cam_states, estimates):
        cam_pos = T_global_camera * cam_state.p_G
        estimate = (T_global_camera * est) + cam_pos
        plt.plot(estimate[0], estimate[1], marker="o", color="green")

    plt.legend(loc=0)
    plt.show()
def x_corr(a, b, center_time_s=1000.0, window_len_s=50.0, plot=True):
    """
    Estimate the lag between two traces from their cross-correlation peak.

    Samples from ``center_time_s - window_len_s`` to
    ``center_time_s + window_len_s`` are converted to indices with ``a.dt``,
    taken from both ``a.trace_x`` and ``b.trace_x``, and cross-correlated.

    Parameters
    ----------
    a, b : objects with a ``trace_x`` sequence
        ``a.dt`` (sampling interval) is used for both traces.
    center_time_s : float
        Centre of the window, in the same time unit as ``a.dt`` (seconds,
        judging by the name).
    window_len_s : float
        Half-width of the window: the slice spans this much time on each side
        of the centre.
    plot : bool
        Unused; kept for backward compatibility.

    Returns
    -------
    float
        ``(argmax(xcorr) - len(xcorr) // 2) * a.dt``, i.e. the offset of the
        correlation peak from the middle of the correlation, scaled by ``a.dt``.
        The middle corresponds to zero lag when ``correlate`` returns the full
        cross-correlation of equal-length inputs (as ``scipy.signal.correlate``
        does by default). This is assumed; confirm which ``correlate`` is imported.

    Raises
    ------
    ValueError
        If the window would start before the beginning of the traces.
    """
    center_index = int(center_time_s / a.dt)
    window_index = int(window_len_s / a.dt)
    print("center_index is %s" % center_index)
    print("window_index is %s" % window_index)
    start = center_index - window_index
    stop = center_index + window_index
    if start < 0:
        # A negative start would wrap around to the end of the trace.
        raise ValueError("window starts before the beginning of the trace "
                         "(start index %d)" % start)
    t1 = a.trace_x[start:stop]
    t2 = b.trace_x[start:stop]
    print(t1)
    x_corr_time = correlate(t1, t2)
    # Integer division: on Python 3, ``/`` would shift the result by half a sample.
    delay = (np.argmax(x_corr_time) - len(x_corr_time) // 2) * a.dt
    return delay
def plotValueFunction(self, valueFunction, prefix):
    '''3d plot of a value function.

    ``valueFunction`` is reshaped to ``(self.numRows, self.numCols)`` and drawn
    as a 'jet' surface. The column axis is mirrored: column j is drawn at
    x = numCols - 1 - j. The figure is saved to
    ``self.outputPath + prefix + 'value_function.png'`` and then closed.
    '''
    _, ax = plt.subplots(subplot_kw=dict(projection='3d'))
    X, Y = np.meshgrid(np.arange(self.numCols), np.arange(self.numRows))
    # Mirror each row of X. This replaces an element-by-element swap loop that
    # relied on Python 2 ``xrange`` and integer ``/``. The unused random
    # colour array is dropped, so the global NumPy RNG is no longer advanced.
    X = X[:, ::-1]
    Z = valueFunction.reshape(self.numRows, self.numCols)
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                    cmap=plt.get_cmap('jet'))
    plt.gca().view_init(elev=30, azim=30)
    plt.savefig(self.outputPath + prefix + 'value_function.png')
    plt.close()
def plotLine(self, x_vals, y_vals, x_label, y_label, title, filename=None):
    """Draw y_vals against x_vals as a thick black line on a cleared figure.

    Both axes are padded by 0.5 beyond the data's min/max. When *filename* is
    None the figure is shown; otherwise it is saved to
    ``self.outputPath + filename``.
    """
    plt.clf()
    plt.xlabel(x_label)
    plt.xlim(((min(x_vals) - 0.5), (max(x_vals) + 0.5)))
    plt.ylabel(y_label)
    plt.ylim(((min(y_vals) - 0.5), (max(y_vals) + 0.5)))
    plt.title(title)
    plt.plot(x_vals, y_vals, c='k', lw=2)
    if filename is None:
        plt.show()
    else:
        plt.savefig(self.outputPath + filename)
demo_mi.py 文件源码
项目:Building-Machine-Learning-Systems-With-Python-Second-Edition
作者: PacktPublishing
项目源码
文件源码
阅读 21
收藏 0
点赞 0
评论 0
def plot_entropy():
    """Plot the binary entropy H(X) of a coin flip against P(heads).

    H(p) = -p*log2(p) - (1-p)*log2(1-p) is evaluated for p in [0.001, 1)
    with step 0.001. The chart is saved as ``entropy_demo.png`` in CHART_DIR.
    """
    pylab.clf()
    pylab.figure(num=None, figsize=(5, 4))
    pylab.title("Entropy $H(X)$")
    pylab.xlabel("$P(X=$coin will show heads up$)$")
    pylab.ylabel("$H(X)$")
    pylab.xlim(xmin=0, xmax=1.1)

    p = np.arange(0.001, 1, 0.001)
    q = 1 - p
    entropy = -p * np.log2(p) - q * np.log2(q)
    pylab.plot(p, entropy)

    pylab.autoscale(tight=True)
    pylab.grid(True)
    pylab.savefig(os.path.join(CHART_DIR, "entropy_demo.png"), bbox_inches="tight")
utils.py 文件源码
项目:Building-Machine-Learning-Systems-With-Python-Second-Edition
作者: PacktPublishing
项目源码
文件源码
阅读 21
收藏 0
点赞 0
评论 0
def plot_roc(auc_score, name, tpr, fpr, label=None):
    """Plot a ROC curve with its area shaded and save it to CHART_DIR.

    The title shows the AUC to two decimals followed by ``/ <label>``; a
    missing label is rendered as ``None``. The file is named
    ``roc_<name>.png``, with spaces in *name* replaced by underscores.
    """
    pylab.clf()
    pylab.figure(num=None, figsize=(5, 4))
    pylab.grid(True)
    pylab.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    pylab.plot(fpr, tpr)
    pylab.fill_between(fpr, tpr, alpha=0.5)
    pylab.xlim([0.0, 1.0])
    pylab.ylim([0.0, 1.0])
    pylab.xlabel('False Positive Rate')
    pylab.ylabel('True Positive Rate')
    title = 'ROC curve (AUC = %0.2f) / %s' % (auc_score, label)
    pylab.title(title, verticalalignment="bottom")
    # NOTE(review): no plotted artist has a label, so this legend is
    # presumably empty (matplotlib may warn).
    pylab.legend(loc="lower right")
    out_name = "roc_" + name.replace(" ", "_") + ".png"
    pylab.savefig(os.path.join(CHART_DIR, out_name), bbox_inches="tight")
def plotKChart(self, misClassDict, saveFigPath):
    """Plot ``1 - misClassNum / k`` against k, then save and show the figure.

    Parameters
    ----------
    misClassDict : dict
        Maps each k to a count; the plotted value is ``1.0 - misClassNum / k``
        and the y-axis is labelled "Misclassified Rate". Whether the count is
        of misses or hits is not visible here; confirm with the caller.
    saveFigPath : str
        Used as the figure identifier, title, legend label and output path.
    """
    # Sort by k so the dashed line joins the points from left to right (plain
    # dict order is arbitrary on Python 2). ``items()`` works on Python 2 and 3.
    kList = []
    misRateList = []
    for k, misClassNum in sorted(misClassDict.items()):
        kList.append(k)
        misRateList.append(1.0 - 1.0 / k * misClassNum)
    plt.figure(saveFigPath)
    plt.plot(kList, misRateList, 'r--', label=saveFigPath)
    plt.title(saveFigPath)
    plt.xlabel('k Num.')
    plt.ylabel('Misclassified Rate')
    # Attach the label to the line; ``plt.legend(saveFigPath)`` treated the
    # string as a sequence of one-character labels.
    plt.legend()
    plt.grid(True)
    plt.savefig(saveFigPath)
    plt.show()
################################### PART3 TEST ########################################
# ??
def linear_testing():
    """Fit a noisy straight line with qudi_fitting's linear model and plot it.

    The data are the model evaluated with slope 2 and offset 50 on 100 points
    over [1, 51], plus Gaussian noise scaled by 10. The fit report is printed,
    and the following are plotted and shown:
    - the data (blue dots);
    - the true model (green);
    - the best fit (red);
    - ``result.init_fit``, the model at the fit's initial parameters (yellow).
    """
    x_axis = np.linspace(1, 51, 100)
    x_nice = np.linspace(x_axis[0], x_axis[-1], 100)
    mod, params = qudi_fitting.make_linear_model()
    print('Parameters of the model', mod.param_names,
          ' with the independent variable', mod.independent_vars)
    params['slope'].value = 2
    params['offset'].value = 50
    data_noisy = (mod.eval(x=x_axis, params=params)
                  + 10 * np.random.normal(size=x_axis.shape))
    result = qudi_fitting.make_linear_fit(axis=x_axis, data=data_noisy, add_parameters=None)
    plt.plot(x_axis, data_noisy, 'ob')
    plt.plot(x_nice, mod.eval(x=x_nice, params=params), '-g')
    print(result.fit_report())
    plt.plot(x_axis, result.best_fit, '-r', linewidth=2.0)
    plt.plot(x_axis, result.init_fit, '-y', linewidth=2.0)
    plt.show()
def plot_penalty_vl(debug, tag, fold_exp):
    """Plot each layer's recorded penalty over mini-batches and save it.

    Parameters
    ----------
    debug : dict
        ``debug["penalty"]``: array-like with one row per mini-batch and one
        column per layer. A 1-D sequence is treated as a single layer.
        ``debug["names"]``: one name per layer column.
    tag : str
        Appended to the title.
    fold_exp : str
        Output directory; the figure is written to ``fold_exp + "/penalty.png"``.
    """
    plt.close("all")
    vl = np.array(debug["penalty"])
    if vl.ndim == 1:
        # A single layer recorded as a flat list: make it one column.
        vl = vl[:, None]
    names = debug["names"]
    fig = plt.figure(figsize=(15, 10.8), dpi=300)
    # vl[:, i] is 1-D for every layer count, so no special case is needed
    # for a single column.
    for i in range(vl.shape[1]):
        plt.plot(vl[:, i], label="layer_" + str(names[i]))
    plt.xlabel("mini-batches")
    plt.ylabel("value of penalty")
    plt.title(
        "Penalty value over layers:" + "_".join([str(k) for k in names]) +
        ". tag:" + tag)
    plt.legend(loc='upper right', fancybox=True, shadow=True, prop={'size': 8})
    plt.grid(True)
    fig.savefig(fold_exp + "/penalty.png", bbox_inches='tight')
    plt.close('all')
def add_cifar_10(x, x_cifar_10, sh=True):
    """Add cifar 10 as background.

    Each zero-valued pixel of each image in ``x`` is replaced by the
    grey level of a CIFAR-10 image:
    - the grey level is the channel mean divided by 255, cropped to the
      central 28x28 (rows/cols 2..29);
    - non-zero pixels of ``x`` are kept unchanged;
    - the result is clipped to [0, 1].

    Parameters
    ----------
    x : ndarray, shape (N, 784)
        Flattened 28x28 images. The shape is required by the element-wise
        product with the flattened background.
    x_cifar_10 : ndarray
        M images reshapeable to (M, 3, 32, 32); presumably raw 0-255 pixels,
        given the division by 255.
    sh : bool, optional
        If True (default), each image gets a randomly chosen background
        (with replacement, using ``np.random``). If False, image i gets
        background i, which requires M == N.

    Returns
    -------
    ndarray of the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``sh`` is False and ``x_cifar_10`` does not have one image per row of ``x``.
    """
    sz = x.shape
    mask = (x == 0) * 1.
    # Grey-scale CIFAR by averaging the colour channels.
    back = x_cifar_10.reshape(x_cifar_10.shape[0], 3, 32, 32).mean(1)
    back = back[:, 2:30, 2:30]  # take 28x28 from the center.
    back /= 255.
    back = back.astype(np.float32)
    if sh:
        ind = np.random.randint(0, x_cifar_10.shape[0], sz[0])
        # Shuffling already-random indices adds no randomness. The loop is kept
        # so results under a fixed np.random seed stay reproducible.
        for _ in range(10):
            np.random.shuffle(ind)
    else:
        # Used only to plot images for the paper. This is a real check rather
        # than an ``assert``, which is stripped under ``python -O``.
        if x_cifar_10.shape[0] != sz[0]:
            raise ValueError(
                "sh=False needs one CIFAR image per input image: got %d and %d"
                % (x_cifar_10.shape[0], sz[0]))
        ind = np.arange(0, sz[0])
    back_sh = back[ind]
    back_sh = back_sh.reshape(back_sh.shape[0], -1)
    back_ready = np.multiply(back_sh, mask)
    out = np.clip(x + back_ready, 0., 1.)
    return out
def plot_roc(y_test, y_pred, label=''):
    """Compute ROC curve and ROC area, then plot and show it.

    The curve's legend entry reports the area to two decimals. *label* is
    appended directly (no separator) to the title
    'Receiver operating characteristic'.
    """
    fpr, tpr, _thresholds = roc_curve(y_test, y_pred)
    area = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % area)
    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic' + label)
    plt.legend(loc="lower right")
    plt.show()
def gen_overview_plot_image(ax, imagefile, imgext=0, cubelayer=1, title='Img Title?', fontsize=6, lthick=2, alpha=0.5,
                            cmap='coolwarm'):
    """
    Plotting commands for image (cube layer) overview plotting

    --- INPUT ---
    ax          Axes to draw on
    imagefile   Path to the FITS file. If it does not exist, a "No image found"
                text with a red dashed cross is drawn instead.
    imgext      FITS extension holding the image data
    cubelayer   If the content of the file is a cube, provide the cube layer to plot. If
                cubelayer = 'fmax' the layer with most flux is plotted, i.e. the
                layer with the largest summed pixel value (NaNs ignored).
    title       Axes title
    fontsize    Font size for the title and the placeholder text
    lthick      Line width of the placeholder cross
    alpha       Not used by this function
    cmap        Colour map for the image
    """
    ax.set_title(title, fontsize=fontsize)
    if os.path.isfile(imagefile):
        # Close the FITS file once the data have been read; previously the
        # handle was left open.
        with pyfits.open(imagefile) as hdulist:
            imgdata = hdulist[imgext].data
        if len(imgdata.shape) == 3:  # it is a cube
            if isinstance(cubelayer, str) and cubelayer == 'fmax':
                import numpy as np
                # 'fmax' was documented but previously raised IndexError.
                cubelayer = int(np.nanargmax(np.nansum(imgdata, axis=(1, 2))))
            imgdata = imgdata[cubelayer, :, :]
        ax.imshow(imgdata, interpolation='None', cmap=cmap, aspect='equal', origin='lower')
        ax.set_xlabel('x-pixel')
        ax.set_ylabel('y-pixel ')
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        textstr = 'No image\nfound'
        ax.text(1.0, 22, textstr, horizontalalignment='center', verticalalignment='center', fontsize=fontsize)
        ax.set_ylim([28, 16])
        ax.plot([0.0, 2.0], [28, 16], 'r--', lw=lthick)
        ax.plot([2.0, 0.0], [28, 16], 'r--', lw=lthick)
        ax.set_xlabel(' ')
        ax.set_ylabel(' ')
        ax.set_xticks([])
        ax.set_yticks([])
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
def energy_profile(self, ixaxis):
    """
    Plot radial profile of key energy generations eps_nuc,
    eps_neu etc.

    Three curves are drawn:
    - log10(eps_nuc), solid black;
    - log10(-eps_nuc), dashed black, which shows where eps_nuc is negative;
    - log10(non_nuc_neu), solid red.

    Parameters
    ----------
    ixaxis : 'mass' or 'radius'
        Any value other than 'mass' plots against radius, converted with
        ``ast.rsun_cm`` and divided by 1e8 (labelled Mm).
    """
    mass = self.get('mass')
    radius = self.get('radius') * ast.rsun_cm
    eps_nuc = self.get('eps_nuc')
    eps_neu = self.get('non_nuc_neu')
    # Raw strings: the LaTeX backslashes were invalid escape sequences (values
    # are unchanged).
    if ixaxis == 'mass':
        xaxis = mass
        xlab = r'Mass / M$_\odot$'
    else:
        xaxis = old_div(radius, 1.e8)  # Mm
        xlab = 'radius / Mm'
    # log10 of non-positive entries gives -inf/NaN instead of finite values.
    # That is how eps_nuc is split into a positive and a negative curve, so
    # the resulting numpy warnings are expected and silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        log_nuc_pos = np.log10(eps_nuc)
        log_nuc_neg = np.log10(-eps_nuc)
        log_neu = np.log10(eps_neu)
    pl.plot(xaxis, log_nuc_pos, 'k-', label=r'$\epsilon_\mathrm{nuc}>0$')
    pl.plot(xaxis, log_nuc_neg, 'k--', label=r'$\epsilon_\mathrm{nuc}<0$')
    pl.plot(xaxis, log_neu, 'r-', label=r'$\epsilon_\nu$')
    pl.xlabel(xlab)
    pl.ylabel(r'$\log(\epsilon_\mathrm{nuc},\epsilon_\nu)$')
    pl.legend(loc='best').draw_frame(False)
def CO_ratio(self, ifig, ixaxis='model'):
    """
    plot surface C/O ratio in Figure ifig with x-axis quantity ixaxis

    The plotted ratio is ``(surface_c12 * 4) / (surface_o16 * 3)``.

    Parameters
    ----------
    ifig : integer
        Figure number in which to plot
    ixaxis : string, optional
        what quantity is to be on the x-axis, either 'time' (star_age) or
        'model' (model_number). The default is 'model'.

    Raises
    ------
    IOError
        If ixaxis is neither 'time' nor 'model'.
    """
    def C_O(model):
        surface_c12 = model.get('surface_c12')
        surface_o16 = model.get('surface_o16')
        # 4/3 = 16/12: converts the ratio of mass fractions to a number ratio
        CORatio = old_div((surface_c12 * 4.), (surface_o16 * 3.))
        return CORatio

    if ixaxis == 'time':
        xax = self.get('star_age')
    elif ixaxis == 'model':
        xax = self.get('model_number')
    else:
        raise IOError("ixaxis not recognised")
    pl.figure(ifig)
    pl.plot(xax, C_O(self))
def hrd_new(self, input_label="", skip=0):
    """
    Plot an HR diagram (log L vs log Teff), optionally skipping leading rows.

    The legend label is ``"M=<initial_mass>, Z=<initial_z>"``, followed by
    ``"; <input_label>"`` when a label is given. Afterwards the x-axis is set
    in decreasing order, spanning the limits from before and after this plot.

    Parameters
    ----------
    input_label : string, optional
        Diagram label. The default is "".
    skip : integer, optional
        Skip the first n lines. The default is 0.
    """
    old_limits = pyl.gca().get_xlim()

    label = ("M=" + str(self.header_attr['initial_mass'])
             + ", Z=" + str(self.header_attr['initial_z']))
    if input_label != "":
        label += "; " + str(input_label)

    teff_col = self.cols['log_Teff'] - 1
    lum_col = self.cols['log_L'] - 1
    pyl.plot(self.data[skip:, teff_col], self.data[skip:, lum_col], label=label)
    pyl.legend(loc=0)
    new_limits = pyl.gca().get_xlim()
    pyl.xlabel('log Teff')
    pyl.ylabel('log L')

    # Reverse the x-axis (max on the left). A limit of exactly 0 in one set is
    # presumably taken to mean "axes not yet meaningful", so only the other
    # set is used in that case.
    if any(array(old_limits) == 0):
        pyl.gca().set_xlim(max(new_limits), min(new_limits))
    elif any(array(new_limits) == 0):
        pyl.gca().set_xlim(max(old_limits), min(old_limits))
    else:
        both = old_limits + new_limits
        pyl.gca().set_xlim([max(both), min(both)])
def t_lumi(self, num_frame, xax):
    """
    Luminosity evolution as a function of time or model.

    Plots log_LH and log_LHe against star age or model number in figure
    *num_frame*.

    Parameters
    ----------
    num_frame : integer
        Number of frame to plot this plot into.
    xax : string
        Either 'model' or 'time' to indicate what is to be used on the
        x-axis. Any other value prints an error message and returns None
        without plotting.
    """
    # Validate before touching the figure. Previously an invalid value printed
    # the message and then crashed with UnboundLocalError on ``xaxisarray``.
    if xax == 'time':
        xaxisarray = self.get('star_age')
        xlabel = 't / yrs'
    elif xax == 'model':
        xaxisarray = self.get('model_number')
        xlabel = 'model number'
    else:
        print('kippenhahn_error: invalid string for x-axis selection. needs to be "time" or "model"')
        return None
    pyl.figure(num_frame)
    logLH = self.get('log_LH')
    logLHe = self.get('log_LHe')
    pyl.plot(xaxisarray, logLH, label='L(H)')
    pyl.plot(xaxisarray, logLHe, label='L(He)')
    pyl.ylabel('log L')
    pyl.legend(loc=2)
    pyl.xlabel(xlabel)
def plot_prof_1(self, mod, species, xlim1, xlim2, ylim1, ylim2,
                symbol=None):
    """
    plot one species for cycle between xlim1 and xlim2

    Delegates to ``DataPlot.plot_prof_1``.

    Parameters
    ----------
    mod : string or integer
        Model to plot, same as cycle number.
    species : list
        Which species to plot.
    xlim1, xlim2 : float
        Mass coordinate range.
    ylim1, ylim2 : float
        Mass fraction coordinate range.
    symbol : string, optional
        Which symbol you want to use. If None symbol is set to '-'.
        The default is None.
    """
    # NOTE(review): arguments are forwarded as (species, mod, ...), the reverse
    # of this method's own order. This presumably matches
    # DataPlot.plot_prof_1's signature; confirm there.
    DataPlot.plot_prof_1(self, species, mod, xlim1, xlim2, ylim1, ylim2, symbol)
def plot_prof_sparse(self, mod, species, xlim1, xlim2, ylim1, ylim2,
                     sparse, symbol):
    """
    Plot log10 of one species' 'yps' profile against mass, thinned out.

    Only every ``sparse``-th point (starting with the first) is drawn.

    Parameters
    ----------
    species : list
        which species to plot.
    mod : string or integer
        Model (cycle) to plot.
    xlim1, xlim2 : float
        Mass coordinate range.
    ylim1, ylim2 : float
        Mass fraction coordinate range.
    sparse : integer
        Sparsity factor for points.
    symbol : string
        which symbol you want to use?
    """
    def thin(values):
        # every `sparse`-th element, starting from index 0
        return values[0:len(values):sparse]

    mass = self.se.get(mod, 'mass')
    abundance = self.se.get(mod, 'yps', species)
    pyl.plot(thin(mass), np.log10(thin(abundance)), symbol)
    pyl.xlim(xlim1, xlim2)
    pyl.ylim(ylim1, ylim2)
    pyl.legend()
def abup_se_plot(mod, species):
    """
    plot species from one ABUPP file and the se file.

    You must use this function in the directory where the ABP files
    are and an ABUP file for model mod must exist.

    Columns 1 and 18 of ``ABUPP<mod:07d>0000.DAT`` (after skipping 4 rows)
    are plotted as a line. The se data ('iso_massf' of 'C-12' against
    'mass') for cycle ``mod`` are plotted as points labelled ``cycle <mod>``.

    Parameters
    ----------
    mod : integer
        Model to plot, you need to have an ABUPP file for that
        model.
    species : string
        The species to plot.

    Notes
    -----
    The species is set to 'C-12'.

    NOTE(review): ``species`` is ignored; 'C-12' is always used. The body also
    refers to ``self``, which is not a parameter, so ``self.se.get`` raises
    NameError unless a ``self`` exists at module scope. Confirm how this is
    meant to be called.
    """
    # Marco, you have already implemented finding headers and columns in
    # ABUP files. You may want to transplant that into here?
    abup_file = 'ABUPP%07d0000.DAT' % mod
    print(abup_file)
    mass, c12 = np.loadtxt(abup_file, skiprows=4, usecols=[1, 18], unpack=True)
    se_c12 = self.se.get(mod, 'iso_massf', 'C-12')
    se_mass = self.se.get(mod, 'mass')
    pyl.plot(mass, c12)
    pyl.plot(se_mass, se_c12, 'o', label='cycle ' + str(mod))
    pyl.legend()
draw.py 文件源码
项目:uai2017_learning_to_acquire_information
作者: evanthebouncy
项目源码
文件源码
阅读 27
收藏 0
点赞 0
评论 0
def draw(m, name, extra=None):
    """Render an image in grey scale on the global figure FIG and save it.

    Parameters
    ----------
    m : array-like
        Image whose last axis is removed by reshaping, so that axis must have
        length 1 (e.g. shape (H, W, 1)).
    name : str
        Output path passed to ``plt.savefig``.
    extra : tuple, optional
        ``((grn_x, grn_y), (red_x, red_y))``: coordinates of points overlaid
        in green and red (marker size 40).
    """
    FIG.clf()
    # Drop the trailing channel axis.
    matrix = np.reshape(m, np.shape(m)[:-1])
    ax = FIG.add_subplot(1, 1, 1)
    ax.set_aspect('equal')
    plt.imshow(matrix, interpolation='nearest', cmap=plt.cm.gray)
    plt.colorbar()
    # Use ``is not None``: ``!= None`` is element-wise (and ambiguous in an
    # ``if``) when extra is an array.
    if extra is not None:
        (grn_x, grn_y), (red_x, red_y) = extra
        plt.scatter(x=grn_x, y=grn_y, c='g', s=40)
        plt.scatter(x=red_x, y=red_y, c='r', s=40)
    plt.savefig(name)