def test_download_prediction_csv_regr(driver, project, dataset, featureset, model, prediction):
    """Download a project's regression predictions as CSV and check the file.

    The ``dataset``, ``featureset``, ``model`` and ``prediction`` arguments are
    not referenced in the body; presumably they are pytest fixtures requested
    only for their setup side effects (TODO confirm).

    The test loads ``/`` and triggers the download for ``project`` via
    ``_click_download``. It then requires
    ``/tmp/cesium_prediction_results.csv`` to exist, checks the header row and
    the numeric data rows, and deletes the file whether or not the checks
    pass.
    """
    csv_path = '/tmp/cesium_prediction_results.csv'
    expected_header = ['ts_name', 'label', 'prediction']
    expected_rows = [
        [0, 2.2, 2.2],
        [1, 3.4, 3.4],
        [2, 4.4, 4.4],
        [3, 2.2, 2.2],
        [4, 3.1, 3.1],
    ]

    driver.get('/')
    _click_download(project.id, driver)
    assert os.path.exists(csv_path)

    try:
        # Read every cell as text so the header row can be compared verbatim.
        table = np.genfromtxt(csv_path, dtype='str', delimiter=',')
        npt.assert_equal(table[0], expected_header)

        # Data rows are all numeric; convert each cell before comparing.
        numeric_rows = [list(map(float, row)) for row in table[1:]]
        npt.assert_array_almost_equal(numeric_rows, expected_rows)
    finally:
        # Remove the downloaded file regardless of the assertion outcome.
        os.remove(csv_path)
评论列表
文章目录