format.py source code

python

Project: treecat    Author: posterior
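import logging

import numpy as np

logger = logging.getLogger(__name__)

# CATEGORICAL and ORDINAL are feature-type constants assumed to be defined
# elsewhere in treecat's format.py (not shown in this snippet).


def export_rows(schema, data):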
    """Export multiple rows of internal data to json format.

    Args:
        schema: A schema dict as returned by load_schema().
        data: An [N, R]-shaped numpy array of ragged data, where N is the
            number of rows and R = schema['ragged_index'][-1].

    Returns:
        An N-long list of sparse dicts mapping feature names to json values,
        where N is the number of rows.
    """
    # Standard logging interpolates args with %-style formatting, not str.format().
    logger.debug('Exporting %d rows', data.shape[0])
    assert data.dtype == np.int8
    assert len(data.shape) == 2
    ragged_index = schema['ragged_index']
    assert data.shape[1] == ragged_index[-1]
    feature_names = schema['feature_names']
    feature_types = schema['feature_types']
    categorical_values = schema['categorical_values']
    ordinal_ranges = schema['ordinal_ranges']

    rows = [{} for _ in range(data.shape[0])]
    for external_row, internal_row in zip(rows, data):
        for v, name in enumerate(feature_names):
            beg, end = ragged_index[v:v + 2]
            internal_cell = internal_row[beg:end]
            if np.all(internal_cell == 0):
                continue
            typename = feature_types[name]
            if typename == CATEGORICAL:
                assert internal_cell.sum() == 1, internal_cell
                value = categorical_values[name][internal_cell.argmax()]
            elif typename == ORDINAL:
                min_max = ordinal_ranges[name]
                assert internal_cell.sum() == min_max[1] - min_max[0]
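                # Cast the int8 count to a Python int: this avoids int8
                # overflow and keeps the result JSON-serializable.
                value = int(internal_cell[0]) + min_max[0]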
            else:
                raise ValueError(typename)
            external_row[name] = value
    return rows
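The ragged layout that export_rows decodes can be illustrated with a small encoder that runs in the opposite direction. This is a minimal sketch, not treecat's own import code: the helper names make_ragged_index and encode_rows_sketch are hypothetical, the string values given to CATEGORICAL and ORDINAL are assumptions, and so is the two-column ordinal layout [value - min, max - value], chosen because it satisfies both assertions in export_rows (the cell sums to max - min, and cell[0] + min recovers the value). Categorical features are one-hot, one column per category, matching the argmax lookup above.

```python
import numpy as np

# Hypothetical feature-type tags; treecat's real constant values may differ.
CATEGORICAL = 'categorical'
ORDINAL = 'ordinal'


def make_ragged_index(schema):
    """Return cumulative column offsets, one [beg, end) slice per feature.

    Assumes a categorical feature takes one column per category and an
    ordinal feature takes exactly two columns.
    """
    offsets = [0]
    for name in schema['feature_names']:
        if schema['feature_types'][name] == CATEGORICAL:
            width = len(schema['categorical_values'][name])
        else:
            width = 2
        offsets.append(offsets[-1] + width)
    return np.array(offsets, dtype=np.int32)


def encode_rows_sketch(schema, rows):
    """Hypothetical inverse of export_rows: sparse dicts -> [N, R] int8 array."""
    ragged_index = schema['ragged_index']
    data = np.zeros((len(rows), ragged_index[-1]), dtype=np.int8)
    # Iterating a 2-D array yields row views, so writes land in `data`.
    for internal_row, external_row in zip(data, rows):
        for v, name in enumerate(schema['feature_names']):
            if name not in external_row:
                continue  # missing feature: leave its slice all zeros
            beg, end = ragged_index[v:v + 2]
            value = external_row[name]
            if schema['feature_types'][name] == CATEGORICAL:
                # One-hot: a single 1 at the category's position.
                pos = schema['categorical_values'][name].index(value)
                internal_row[beg + pos] = 1
            else:
                # Two counts [value - min, max - value] summing to max - min.
                lo, hi = schema['ordinal_ranges'][name]
                internal_row[beg:end] = [value - lo, hi - value]
    return data


schema = {
    'feature_names': ['color', 'rating'],
    'feature_types': {'color': CATEGORICAL, 'rating': ORDINAL},
    'categorical_values': {'color': ['red', 'green', 'blue']},
    'ordinal_ranges': {'rating': (1, 5)},
}
schema['ragged_index'] = make_ragged_index(schema)
rows = [{'color': 'green', 'rating': 4}, {'rating': 1}, {}]
data = encode_rows_sketch(schema, rows)
print(schema['ragged_index'])  # [0 3 5]
print(data.tolist())  # [[0, 1, 0, 3, 1], [0, 0, 0, 0, 4], [0, 0, 0, 0, 0]]
```

Assuming the constants match the ones export_rows compares against, passing this data and schema to export_rows should recover the same values: the empty dict stays empty because both of its slices are all zeros, and the rating of 1 survives because its cell [0, 4] is not all zeros. One consequence of this layout is that an ordinal feature with min == max would encode to [0, 0] and be read back as missing.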