deprecated_flu_prediction.py source code

python

Project: augur  Author: nextstrain
def plot_prediction(self):
    '''
    Plot the global frequencies, the predicted frequencies, and the
    frequencies in the short interval used for learning.
    '''
    # numpy is used below as np; it is imported here so the snippet stands
    # alone (the original file presumably imports it at module level).
    import numpy as np
    from matplotlib import pyplot as plt
    import seaborn as sns
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # Vertical lines at self.t_cut (dashed) and at
    # self.current_prediction_interval[1] (solid).
    axs[0].plot(self.t_cut*np.ones(2), [0, 1], lw=3, alpha=0.3, c='k', ls='--')
    axs[0].plot(self.current_prediction_interval[1]*np.ones(2), [0, 1], lw=3, alpha=0.3, c='k')

    # Pivots and per-clade frequencies of the training interval.
    train_pivots = self.train_frequencies[self.current_prediction_interval][0]
    train_freqs = self.train_frequencies[self.current_prediction_interval][1]
    cols = sns.color_palette()
    # Boolean mask: global pivots that lie after the last training pivot.
    future_pivots = self.global_pivots > train_pivots[-1]
    for node in self.predictions:
        # Skip clades whose predicted frequency never exceeds 2%
        # after the first training pivot.
        if np.max(self.predictions[node][self.global_pivots > train_pivots[0]]) > 0.02:
            color = cols[node.clade % 6]
            # Dashed: prediction; solid: global frequencies; dash-dot: training frequencies.
            axs[0].plot(self.global_pivots[future_pivots],
                        self.predictions[node][future_pivots], ls='--', c=color)
            axs[0].plot(self.global_pivots, self.global_freqs[node.clade], ls='-', c=color)
            axs[0].plot(train_pivots, train_freqs[node.clade], ls='-.', c=color)

    axs[0].set_xlim(train_pivots[0]-2, train_pivots[-1]+2)

    # Right panel: prediction error, zeroed outside the future pivots.
    dev = self.prediction_error()
    dev[~future_pivots] = 0.0
    axs[1].plot(self.global_pivots, dev)
    axs[1].set_xlim(train_pivots[0], train_pivots[-1]+2)
    axs[1].set_ylim(0, 3)
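
The method relies on NumPy boolean masks to decide which pivots count as "future" and which clades are worth drawing. The following is a minimal sketch of that selection logic only, with no plotting. All values are made up for illustration and do not come from augur; clades are keyed by a plain integer id here instead of tree nodes, and `dev` stands in for whatever `self.prediction_error()` returns.

```python
import numpy as np

# Hypothetical pivots (time points in years); not real augur data.
global_pivots = np.array([2010.0, 2010.5, 2011.0, 2011.5, 2012.0])
train_pivots = np.array([2010.0, 2010.5, 2011.0])

# Hypothetical predicted frequencies per clade id, one value per global pivot.
predictions = {
    0: np.array([0.10, 0.20, 0.30, 0.45, 0.60]),
    1: np.array([0.05, 0.01, 0.01, 0.01, 0.00]),
}

# True for pivots strictly after the last training pivot.
future_pivots = global_pivots > train_pivots[-1]

# True for pivots strictly after the first training pivot.
after_start = global_pivots > train_pivots[0]

# Keep only clades whose predicted frequency exceeds 2% somewhere after the start.
plotted = [clade for clade, freq in predictions.items()
           if np.max(freq[after_start]) > 0.02]

# Stand-in for the per-pivot prediction error; zero it outside the future window.
dev = np.array([0.5, 0.4, 0.3, 0.2, 0.1])
dev[~future_pivots] = 0.0
```

With these numbers only clade 0 passes the 2% filter, and `dev` keeps its values only at the last two pivots, which is why the right-hand panel of the plot is flat at zero up to the end of the training interval.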