latent_variables.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:pyflux 作者: RJT1990 项目源码 文件源码
def trace_plot(self, figsize=(15, 15)):
        """Draw diagnostic plots for the sampled latent variables.

        The figure has one row per entry of ``self.z_list`` and four columns:

        1. Density estimate of ``z.prior.transform(z.sample)``.
        2. Trace plot of the transformed samples.
        3. Cumulative (running) average of the transformed samples.
        4. Bar chart of ``acf(z.sample, lag)`` for lags 1-9. The ACF is
           computed on the raw, untransformed chain, as before.

        Column titles are set on the first row only. The row label
        (y-axis of the density plot) is ``z.name``.

        Parameters
        ----------
        figsize : tuple
            Figure size forwarded to ``matplotlib.pyplot.figure``.

        Raises
        ------
        ValueError
            If ``self.z_list`` is empty, or if its first entry has no
            ``sample`` attribute.
        """
        # Validate before importing the plotting stack. This gives a clear
        # error (instead of IndexError on an empty list) even where
        # matplotlib/seaborn are not installed.
        if not self.z_list or not hasattr(self.z_list[0], 'sample'):
            raise ValueError("No samples to plot!")

        import matplotlib.pyplot as plt
        import seaborn as sns

        # Six base colours, cycled by row index (same colours the original
        # obtained by repeating the list len(z_list) times).
        palette = [
            (0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
            (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
            (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
            (0.5058823529411764, 0.4470588235294118, 0.6980392156862745),
            (0.8, 0.7254901960784313, 0.4549019607843137),
            (0.39215686274509803, 0.7098039215686275, 0.803921568627451),
        ]
        titles = ('Density Estimate', 'Trace Plot', 'Cumulative Average', 'ACF Plot')
        lags = range(1, 10)
        n_rows = len(self.z_list)

        fig = plt.figure(figsize=figsize)

        for row, z in enumerate(self.z_list):
            chain = z.sample
            # Transform once and reuse it for the first three columns.
            transformed = z.prior.transform(chain)
            color = palette[row % len(palette)]
            axes = [fig.add_subplot(n_rows, 4, row * 4 + col + 1) for col in range(4)]

            density_ax, trace_ax, mean_ax, acf_ax = axes

            sns.distplot(transformed, rug=False, hist=False, color=color, ax=density_ax)
            density_ax.set_ylabel(z.name)

            trace_ax.plot(transformed, color=color)

            # Running mean: cumulative sum divided by 1, 2, ..., len(chain).
            mean_ax.plot(np.cumsum(transformed) / np.arange(1, len(chain) + 1), color=color)

            acf_ax.bar(lags, [acf(chain, lag) for lag in lags], color=color)

            if row == 0:
                for ax, title in zip(axes, titles):
                    ax.set_title(title)

        # seaborn no longer re-exports pyplot as ``sns.plt``, so call pyplot directly.
        plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号