def trace_plot(self, figsize=(15, 15)):
    """Plot sampling diagnostics for every latent variable in ``self.z_list``.

    Each latent variable gets one row of four panels:

    1. a kernel density estimate of the transformed samples,
    2. the trace of the transformed samples,
    3. the cumulative (running) average of the transformed samples,
    4. a bar chart of ``acf(chain, lag)`` for lags 1-9, computed on the
       raw (untransformed) chain.

    Parameters
    ----------
    figsize : tuple of (float, float)
        Size of the matplotlib figure, passed to ``plt.figure``.

    Raises
    ------
    ValueError
        If ``self.z_list`` is empty or any latent variable has no ``sample``
        attribute.
    """
    # Validate before importing the plotting stack: the check is cheap and
    # its error should not depend on matplotlib/seaborn being importable.
    # All variables are checked, not just the first, so we never fail with an
    # AttributeError halfway through drawing the figure.
    z_list = self.z_list
    if not z_list or not all(hasattr(z, 'sample') for z in z_list):
        raise ValueError("No samples to plot!")

    import matplotlib.pyplot as plt
    import seaborn as sns

    # Fixed six-colour cycle; indexed modulo its length so any number of
    # latent variables is supported without copying the list.
    palette = [(0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
               (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
               (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
               (0.5058823529411764, 0.4470588235294118, 0.6980392156862745),
               (0.8, 0.7254901960784313, 0.4549019607843137),
               (0.39215686274509803, 0.7098039215686275, 0.803921568627451)]
    n_rows = len(z_list)
    n_cols = 4
    lags = range(1, 10)

    fig = plt.figure(figsize=figsize)
    for row, z in enumerate(z_list):
        chain = z.sample
        # Transform once; the first three panels all show transformed values.
        values = z.prior.transform(chain)
        color = palette[row % len(palette)]
        base = row * n_cols  # subplot indices are 1-based, row-major
        first_row = row == 0

        # Panel 1: density estimate. sns.kdeplot replaces the deprecated
        # sns.distplot(..., hist=False, rug=False), which drew this same curve.
        ax = fig.add_subplot(n_rows, n_cols, base + 1)
        sns.kdeplot(values, color=color, ax=ax)
        ax.set_ylabel(z.name)
        if first_row:
            ax.set_title('Density Estimate')

        # Panel 2: trace of the transformed samples.
        ax = fig.add_subplot(n_rows, n_cols, base + 2)
        ax.plot(values, color=color)
        if first_row:
            ax.set_title('Trace Plot')

        # Panel 3: cumulative average, i.e. mean of the first i samples.
        ax = fig.add_subplot(n_rows, n_cols, base + 3)
        ax.plot(np.cumsum(values) / np.arange(1, len(values) + 1), color=color)
        if first_row:
            ax.set_title('Cumulative Average')

        # Panel 4: autocorrelation of the *untransformed* chain at lags 1-9
        # (``acf`` is a helper defined elsewhere in this module).
        ax = fig.add_subplot(n_rows, n_cols, base + 4)
        ax.bar(lags, [acf(chain, lag) for lag in lags], color=color)
        if first_row:
            ax.set_title('ACF Plot')

    # seaborn no longer re-exports pyplot as ``sns.plt``; use pyplot directly.
    plt.show()
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")