def _check_transformer_output(transformer, dataset, expected):
    """
    Check that a transformer produces the expected values on a Spark dataset.

    The dataset is passed through ``tfs.analyze`` and then through
    ``transformer.transform``. Every column named in
    ``_output_mapping.values()`` is collected, the values of each row are
    flattened with ``np.ravel``, and all rows are concatenated into a single
    1-D array. That array is compared against ``expected`` with
    ``np.allclose``, using ``_all_close_tolerance`` as the absolute tolerance.

    :param transformer: object exposing ``transform(df)``.
    :param dataset: the Spark dataset handed to ``tfs.analyze``.
    :param expected: array-like of reference values; it is converted with
        ``np.asarray`` so plain sequences are accepted as well as ndarrays.
    :raises AssertionError: if the transformer yields no rows, if the shapes
        of ``expected`` and the collected output cannot be combined, or if
        the values differ by more than the tolerance.
    """
    analyzed_df = tfs.analyze(dataset)
    out_df = transformer.transform(analyzed_df)

    # Collect transformed values, one flattened array per row.
    out_colnames = list(_output_mapping.values())
    flattened_rows = [
        np.ravel([row[colname] for colname in out_colnames])
        for row in out_df.select(out_colnames).collect()
    ]
    # np.hstack raises an opaque ValueError on an empty list; fail with a
    # readable message instead.
    assert flattened_rows, 'transformer produced no output rows'
    out_tgt = np.hstack(flattened_rows)

    expected = np.asarray(expected)
    try:
        # Raises ValueError when the shapes cannot be broadcast together (or
        # the broadcast result is empty). Without this guard the shape
        # information in the message below would never be reported.
        max_diff = np.max(np.abs(expected - out_tgt))
    except ValueError:
        max_diff = float('nan')
        is_close = False
    else:
        is_close = np.allclose(expected, out_tgt, atol=_all_close_tolerance)

    _err_msg = 'not close => shape {} != {}, max_diff {} > {}'
    err_msg = _err_msg.format(expected.shape, out_tgt.shape,
                              max_diff, _all_close_tolerance)
    assert is_close, err_msg
tf_transformer_test.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录