def plot_2D_heat_map(states, p, labels, inter=False):
    """Render a 2-D heat map of values ``p`` placed at integer grid cells.

    Column ``i`` of ``states`` gives the (x, y) cell that receives ``p[i]``;
    cells that receive no value stay 0.  The grid spans ``0..max`` along
    each axis, where the maxima are taken over *every* column of
    ``states`` (not only the first ``len(p)``).

    Parameters
    ----------
    states : array_like
        At least two rows: row 0 holds x indices, row 1 holds y indices.
        Coordinates must be non-negative integers; integer-valued floats
        are accepted.  Any further rows are ignored.
    p : array_like
        Value for each of the first ``len(p)`` columns of ``states``.
        If a cell appears more than once, the later value is kept.
    labels : sequence of str
        ``labels[0]`` is the x-axis label, ``labels[1]`` the y-axis label.
    inter : bool, optional
        If true, refresh the current figure with ``draw()`` (for
        interactive use) instead of calling ``show()``.

    Raises
    ------
    ValueError
        If ``states`` has no columns, or contains non-integer or negative
        coordinates.
    IndexError
        If ``p`` has more entries than ``states`` has columns.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Only the x and y rows are used; extra rows must not affect validation.
    coords = np.asarray(states)[:2]
    values = np.asarray(p)

    if coords.size == 0:
        raise ValueError("states must contain at least one coordinate")
    if len(values) > coords.shape[1]:
        raise IndexError(
            f"p has {len(values)} entries but states has only "
            f"{coords.shape[1]} columns"
        )

    # Accept integer-valued floats (common when states were stored in a
    # float array) but refuse values that would silently be truncated.
    if not np.issubdtype(coords.dtype, np.integer):
        rounded = np.rint(coords)
        if not np.array_equal(rounded, coords):
            raise ValueError("states must contain integer grid coordinates")
        coords = rounded.astype(int)

    xs, ys = coords[0], coords[1]
    # Negative indices would wrap around to the far edge of the grid.
    if xs.min() < 0 or ys.min() < 0:
        raise ValueError("states must not contain negative coordinates")

    grid = np.zeros((xs.max() + 1, ys.max() + 1))
    # A plain loop keeps "last write wins" for repeated cells; NumPy gives
    # no ordering guarantee for fancy-index assignment with duplicates.
    # zip stops after len(values) columns, like the original range(len(p)).
    for x, y, value in zip(xs, ys, values):
        grid[x, y] = value

    plt.clf()
    # Transpose so x runs along the horizontal axis; origin='lower' puts
    # cell (0, 0) in the lower-left corner.
    plt.imshow(grid.T, origin='lower')
    plt.xlabel(labels[0])
    plt.ylabel(labels[1])
    if inter:
        plt.draw()
    else:
        plt.show()
评论列表
文章目录