def plot_prediction(self):
    '''
    Plot the global frequencies, the predicted frequencies, and the frequencies
    in the short interval used for learning.

    Draws a two-panel matplotlib figure for ``self.current_prediction_interval``:

    * Left panel: for every node in ``self.predictions`` whose prediction
      exceeds 0.02 somewhere after the first training pivot, the global
      trajectory (solid), the training-interval trajectory (dash-dot) and the
      prediction beyond the last training pivot (dashed), sharing one colour
      per ``node.clade``. Two vertical grey lines mark ``self.t_cut`` (dashed)
      and ``self.current_prediction_interval[1]`` (solid).
    * Right panel: the result of ``self.prediction_error()`` plotted against
      ``self.global_pivots``, zeroed at every pivot that is not beyond the
      last training pivot.

    Returns nothing; the figure is neither shown nor saved here.
    '''
    # Plotting libraries are imported lazily so the module itself does not
    # require them unless a plot is actually requested.
    from matplotlib import pyplot as plt
    import seaborn as sns

    _, axs = plt.subplots(1, 2, figsize=(12, 6))
    freq_ax, error_ax = axs[0], axs[1]

    # Vertical reference lines spanning the full frequency range [0, 1].
    freq_ax.plot(self.t_cut * np.ones(2), [0, 1],
                 lw=3, alpha=0.3, c='k', ls='--')
    freq_ax.plot(self.current_prediction_interval[1] * np.ones(2), [0, 1],
                 lw=3, alpha=0.3, c='k')

    training = self.train_frequencies[self.current_prediction_interval]
    train_pivots, train_freqs = training[0], training[1]

    cols = sns.color_palette()
    # Cycle through however many colours the active palette provides instead
    # of assuming it has at least six entries.
    n_cols = len(cols)

    # Both masks are independent of the node, so compute them once.
    future_pivots = self.global_pivots > train_pivots[-1]
    after_training_start = self.global_pivots > train_pivots[0]

    for node in self.predictions:
        predicted = self.predictions[node]
        # Only draw clades whose predicted frequency is non-negligible.
        if np.max(predicted[after_training_start]) > 0.02:
            colour = cols[node.clade % n_cols]
            freq_ax.plot(self.global_pivots[future_pivots],
                         predicted[future_pivots], ls='--', c=colour)
            freq_ax.plot(self.global_pivots, self.global_freqs[node.clade],
                         ls='-', c=colour)
            freq_ax.plot(train_pivots, train_freqs[node.clade],
                         ls='-.', c=colour)
    freq_ax.set_xlim(train_pivots[0] - 2, train_pivots[-1] + 2)

    # Work on a copy so that zeroing the non-future pivots cannot leak into
    # the array object handed back by prediction_error().
    dev = self.prediction_error().copy()
    dev[~future_pivots] = 0.0
    error_ax.plot(self.global_pivots, dev)
    error_ax.set_xlim(train_pivots[0], train_pivots[-1] + 2)
    error_ax.set_ylim(0, 3)
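# Comment list      (web-page navigation residue, not part of the code)
# Article contents  (web-page navigation residue, not part of the code)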