titer_model.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:augur 作者: nextstrain 项目源码 文件源码
def validate(self, plot=False, cutoff=0.0, validation_set = None, fname=None):
        '''
        Predict titers for a validation set and compare them with measured values.

        For every ``key -> measured`` entry of ``validation_set`` the model's
        prediction ``self.predict_titer(key[0], key[1], cutoff=cutoff)`` is
        computed. The pair ``(measured, predicted)`` is stored in
        ``self.validation[key]``, and summary statistics are computed and printed.

        Parameters
        ----------
        plot : bool
            If True, draw a scatter plot of predicted vs. measured values,
            annotated with the regularization parameters (``self.lam_drop``,
            ``self.lam_pot``, ``self.lam_avi``) and the error metrics.
        cutoff : float
            Forwarded unchanged to ``self.predict_titer``.
        validation_set : mapping or None
            Maps keys to measured values. Elements 0 and 1 of each key are
            passed to ``predict_titer``. Defaults to ``self.test_titers``
            (the titers set aside previously).
        fname : str or None
            If given together with ``plot=True``, the figure is saved to
            this path.

        Attributes set
        --------------
        self.validation : dict mapping key -> (measured, predicted)
        self.abs_error : mean absolute error of predicted vs. measured
        self.rms_error : root-mean-square error of predicted vs. measured
        self.slope, self.intercept : linear regression of predicted on measured
        self.r2 : squared Pearson correlation coefficient (printed under the
            label "pearson correlation")

        Raises
        ------
        ValueError
            If the validation set is empty, so no statistics can be computed.
        '''
        from scipy.stats import linregress, pearsonr
        if validation_set is None:
            validation_set = self.test_titers
        self.validation = {}
        # .items() works on Python 2 and 3; .iteritems() is Python-2 only.
        for key, val in validation_set.items():
            pred_titer = self.predict_titer(key[0], key[1], cutoff=cutoff)
            self.validation[key] = (val, pred_titer)

        if not self.validation:
            raise ValueError("validation set is empty; nothing to validate")

        # Rows are (measured, predicted). list() is required on Python 3:
        # numpy wraps a bare dict view as a 0-d object array, not an (n, 2) array.
        a = np.array(list(self.validation.values()))
        measured, predicted = a[:, 0], a[:, 1]
        print ("number of prediction-measurement pairs",a.shape)
        self.abs_error = np.mean(np.abs(measured - predicted))
        self.rms_error = np.sqrt(np.mean((measured - predicted)**2))
        # Only slope and intercept are kept from linregress.
        self.slope, self.intercept = linregress(measured, predicted)[:2]
        print ("error (abs/rms): ",self.abs_error, self.rms_error)
        print ("slope, intercept:", self.slope, self.intercept)
        # pearsonr returns (r, p-value); the square of r is stored.
        self.r2 = pearsonr(measured, predicted)[0]**2
        print ("pearson correlation:", self.r2)

        if plot:
            import matplotlib.pyplot as plt
            import seaborn as sns
            fs=16
            sns.set_style('darkgrid')
            plt.figure()
            ax = plt.subplot(111)
            # Diagonal reference line: points on it are perfect predictions.
            plt.plot([-1,6], [-1,6], 'k')
            plt.scatter(measured, predicted)
            plt.ylabel(r"predicted $\log_2$ distance", fontsize = fs)
            plt.xlabel(r"measured $\log_2$ distance" , fontsize = fs)
            ax.tick_params(axis='both', labelsize=fs)
            plt.text(-2.5,6,'regularization:\nprediction error:\nR^2:', fontsize = fs-2)
            plt.text(1.2,6, str(self.lam_drop)+'/'+str(self.lam_pot)+'/'+str(self.lam_avi)+' (HI/pot/avi)'
                     +'\n'+str(round(self.abs_error, 2))+'/'+str(round(self.rms_error, 2))+' (abs/rms)'
                     + '\n' + str(self.r2), fontsize = fs-2)
            plt.tight_layout()

            if fname:
                plt.savefig(fname)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号