def plot_marginals(state_space, p, name, t, labels=False, interactive=False):
    """Plot the one-dimensional marginals of a distribution over discrete states.

    For each column ``i`` of ``state_space`` the mass ``p`` is summed over all
    states that share the same value in that column, and the resulting
    marginal is drawn in its own subplot of a two-column grid.  Each subplot
    also gets a red vertical line at ``sum(x * p_x)`` (the marginal mean when
    ``p`` sums to 1) and a green vertical line at the value with the largest
    mass (the marginal mode).

    Parameters
    ----------
    state_space : array-like, shape (N, D)
        One state per row, one coordinate ("species") per column.
    p : array-like, shape (N,)
        Mass assigned to each row of ``state_space``.  Presumably a
        probability vector -- the red "mean" line is only a mean if ``p``
        sums to 1; confirm with callers.
    name : object
        Unused; kept so existing callers keep working.
    t : object
        Time value shown in the figure title as ``"time: <t> units"``.
    labels : sequence of str, False or None, optional
        x-axis label for each column.  ``False`` (the default) or ``None``
        labels column ``i`` as ``"Specie: <i+1>"``.
    interactive : bool, optional
        If true, enable pyplot interactive mode and call ``pyplot.draw()``;
        otherwise apply ``tight_layout()`` and call ``pyplot.show()``.
    """
    import numpy as np
    import matplotlib.pyplot as pl

    if interactive:
        pl.ion()
    pl.clf()
    pl.suptitle("time: " + str(t) + " units")

    state_space = np.asarray(state_space)
    weights = np.asarray(p, dtype=float).reshape(-1)
    n_dims = state_space.shape[1]
    # ceil(D / 2) rows; the previous int(D / 2) + 1 left an empty row for even D.
    n_rows = (n_dims + 1) // 2
    # Identity checks: `labels == False` is elementwise (and raises) for arrays.
    use_default_labels = labels is False or labels is None

    for i in range(n_dims):
        # Sum the mass of all states sharing each unique value of column i.
        # bincount over the unique-value index avoids materialising a dense
        # (n_unique x N) indicator matrix.
        marg_X, inverse = np.unique(state_space[:, i], return_inverse=True)
        marg_p = np.bincount(inverse.reshape(-1), weights=weights,
                             minlength=marg_X.size)

        pl.subplot(n_rows, 2, i + 1)
        pl.plot(marg_X, marg_p)
        pl.yticks(np.linspace(np.amin(marg_p), np.amax(marg_p), num=3))
        pl.axvline(np.sum(marg_X * marg_p), color='r')    # mean (if p sums to 1)
        pl.axvline(marg_X[np.argmax(marg_p)], color='g')  # mode
        if use_default_labels:
            pl.xlabel("Specie: " + str(i + 1))
        else:
            pl.xlabel(labels[i])

    if interactive:
        pl.draw()
    else:
        pl.tight_layout()
        pl.show()
# NOTE(review): stray web-page text was here ("Comment list", "Table of contents"); safe to delete.