def test_download_prediction_csv_class_prob(driver, project, dataset,
                                            featureset, model, prediction):
    """Download prediction results as CSV and check the class-probability columns.

    Triggers the download for ``project`` through the browser, then verifies
    the resulting CSV's time-series names, labels, and per-class probability
    columns. The downloaded file is always removed afterwards.

    ``dataset``, ``featureset``, ``model`` and ``prediction`` are not used in
    the body; presumably they are pytest fixtures requested only so the
    corresponding records exist before the download (confirm in conftest).
    """
    # Single source of truth for the download location, shared by the
    # existence check, the read, and the cleanup below.
    csv_path = '/tmp/cesium_prediction_results.csv'

    driver.get('/')
    _click_download(project.id, driver)
    assert os.path.exists(csv_path)
    try:
        result = pd.read_csv(csv_path)
        npt.assert_array_equal(result.ts_name, np.arange(5))
        npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
                                              'Mira', 'Classical_Cepheid',
                                              'Mira'])
        # One probability column per class. The per-row argmax
        # (0 = 'Classical_Cepheid', 1 = 'Mira') must agree with the
        # expected labels asserted above.
        pred_probs = result[['Classical_Cepheid', 'Mira']]
        npt.assert_array_equal(np.argmax(pred_probs.values, axis=1),
                               [1, 0, 1, 0, 1])
        # Probabilities must lie in [0, 1]; checking only the lower bound
        # would let out-of-range scores pass.
        assert (pred_probs.values >= 0.0).all()
        assert (pred_probs.values <= 1.0).all()
    finally:
        # Remove the downloaded file even when an assertion fails, so no
        # stale copy is left behind in /tmp.
        os.remove(csv_path)
评论列表
文章目录