test_predict.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:cesium_web 作者: cesium-ml 项目源码 文件源码
def test_download_prediction_csv_class_prob(driver, project, dataset,
                                            featureset, model, prediction):
    driver.get('/')
    _click_download(project.id, driver)
    assert os.path.exists('/tmp/cesium_prediction_results.csv')
    try:
        result = pd.read_csv('/tmp/cesium_prediction_results.csv')
        npt.assert_array_equal(result.ts_name, np.arange(5))
        npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
                                              'Mira', 'Classical_Cepheid',
                                              'Mira'])
        pred_probs = result[['Classical_Cepheid', 'Mira']]
        npt.assert_array_equal(np.argmax(pred_probs.values, axis=1),
                               [1, 0, 1, 0, 1])
        assert (pred_probs.values >= 0.0).all()
    finally:
        os.remove('/tmp/cesium_prediction_results.csv')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号