metrics.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def compute_error(y, m, v, lik, median=False, no_samples=50):
    if lik == 'gauss':
        y = y.reshape((y.shape[0],))
        if median:
            rmse = np.sqrt(np.median((y - m)**2))
        else:
            rmse = np.sqrt(np.mean((y - m)**2))
        return rmse
    elif lik == 'cdf':
        K = no_samples
        fs = draw_randn_samples(K, m, v).T
        log_factor = stats.norm.logcdf(np.tile(y.reshape((y.shape[0], 1)), (1, K)) * fs)
        ll = logsumexp(log_factor - np.log(K), 1)
        return 1 - np.mean(ll > np.log(0.5))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号