tf_transformer_test.py source code

Project: spark-deep-learning, Author: databricks
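import numpy as np
import tensorframes as tfs

# `_output_mapping` (whose values are the output column names) and
# `_all_close_tolerance` are module-level globals of this test file.


def _check_transformer_output(transformer, dataset, expected):
    """
    Given a transformer and a Spark DataFrame, check whether the transformer
    produces the expected results (a NumPy array).
    """
    # Have TensorFrames analyze the DataFrame and record tensor shape information
    analyzed_df = tfs.analyze(dataset)
    out_df = transformer.transform(analyzed_df)

    # Collect the transformed values: flatten each row, then concatenate all rows
    out_colnames = list(_output_mapping.values())
    _results = []
    for row in out_df.select(out_colnames).collect():
        curr_res = [row[colname] for colname in out_colnames]
        _results.append(np.ravel(curr_res))
    out_tgt = np.hstack(_results)

    # Check shapes first: otherwise NumPy broadcasting could either raise in the
    # subtraction below or silently compare arrays of mismatched shapes
    assert expected.shape == out_tgt.shape, \
        'shape mismatch: {} != {}'.format(expected.shape, out_tgt.shape)

    max_diff = np.max(np.abs(expected - out_tgt))
    err_msg = 'not close => max_diff {} > {}'.format(max_diff, _all_close_tolerance)
    # Note: np.allclose also applies its default rtol (1e-05) on top of atol
    assert np.allclose(expected, out_tgt, atol=_all_close_tolerance), err_msg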
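The flatten-and-compare logic at the end of `_check_transformer_output` can be exercised without Spark or TensorFlow. The following is a minimal sketch, assuming the collected rows behave like mappings (plain dicts stand in for `pyspark.sql.Row` here). `flatten_rows` and `check_close` are hypothetical helper names for illustration, not part of spark-deep-learning.

```python
import numpy as np


def flatten_rows(rows, colnames):
    """Flatten the selected columns of every row into one 1-D float array.

    `rows` is any sequence of mappings (e.g. pyspark.sql.Row objects or dicts).
    """
    per_row = [
        # Ravel each column separately so scalar and vector columns can mix
        np.concatenate([np.ravel(row[c]) for c in colnames])
        for row in rows
    ]
    if not per_row:
        # np.hstack raises on an empty list, so return an empty array instead
        return np.empty(0)
    return np.hstack(per_row).astype(float)


def check_close(expected, actual, atol):
    """Return (ok, message): a shape check, then a purely absolute tolerance check."""
    expected = np.asarray(expected, dtype=float)
    actual = np.asarray(actual, dtype=float)
    if expected.shape != actual.shape:
        return False, 'shape {} != {}'.format(expected.shape, actual.shape)
    # np.max fails on empty arrays, so treat an empty comparison as zero difference
    max_diff = float(np.max(np.abs(expected - actual))) if expected.size else 0.0
    # rtol=0.0 makes the comparison purely absolute, so it agrees with max_diff
    ok = bool(np.allclose(expected, actual, rtol=0.0, atol=atol))
    return ok, 'max_diff {} vs atol {}'.format(max_diff, atol)


# Usage: two rows, each with a vector column 'a' and a scalar column 'b'
rows = [{'a': [1.0, 2.0], 'b': 3.0}, {'a': [4.0, 5.0], 'b': 6.0}]
actual = flatten_rows(rows, ['a', 'b'])  # array([1., 2., 3., 4., 5., 6.])
ok, msg = check_close(np.arange(1.0, 7.0), actual, atol=1e-6)  # ok is True
```

One deliberate difference from the original: `np.allclose` adds a default relative tolerance (`rtol=1e-05`) to `atol`, so the original assertion can pass even when `max_diff` exceeds `_all_close_tolerance`. For example, `np.allclose([1000.0], [1000.005], atol=1e-5)` is `True`, while the sketch's `check_close([1000.0], [1000.005], atol=1e-5)` reports `False` because it passes `rtol=0.0`.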