format.py source code

python

Project: treecat  Author: posterior
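```python
import numpy as np

# load_schema, load_data, pickle_dump, Table and TY_MULTINOMIAL are defined
# elsewhere in the treecat package (not shown in this excerpt).


def import_data(data_csvs_in,
                types_csv_in,
                values_csv_in,
                groups_csv_in,
                dataset_out,
                encoding='utf-8'):
    """Import a comma-delimited list of csv files into internal treecat format.

    Common encodings include: utf-8, cp1252.
    """
    # Build the schema (feature names, types, values, groups) from the metadata csvs.
    schema = load_schema(types_csv_in, values_csv_in, groups_csv_in, encoding)
    # Load every data csv against the same schema and stack their rows.
    data = np.concatenate([
        load_data(schema, data_csv_in, encoding)
        for data_csv_in in data_csvs_in.split(',')
    ])
    data.flags.writeable = False  # freeze the imported array
    print('Imported data shape: [{}, {}]'.format(data.shape[0], data.shape[1]))
    # Feature v occupies columns ragged_index[v]:ragged_index[v + 1] of data.
    ragged_index = schema['ragged_index']
    for v, name in enumerate(schema['feature_names']):
        beg, end = ragged_index[v:v + 2]
        # Count rows with at least one nonzero entry in this feature's columns.
        count = np.count_nonzero(data[:, beg:end].max(1))
        if count == 0:
            print('WARNING: No values found for feature {}'.format(name))
    # Every feature is imported as multinomial.
    feature_types = [TY_MULTINOMIAL] * len(schema['feature_names'])
    table = Table(feature_types, ragged_index, data)
    dataset = {
        'schema': schema,
        'table': table,
    }
    pickle_dump(dataset, dataset_out)
```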
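The empty-feature check in `import_data` relies on a ragged column layout: feature `v` owns the contiguous column block `ragged_index[v]:ragged_index[v + 1]`. The sketch below isolates that check on toy data so the slicing is easy to follow. The helper name `empty_features` and the feature names and values are hypothetical, chosen only for illustration.

```python
import numpy as np


def empty_features(data, ragged_index, feature_names):
    """Return the names of features whose whole column block is zero.

    Hypothetical helper that mirrors the warning loop in import_data:
    feature v owns columns ragged_index[v]:ragged_index[v + 1].
    """
    empty = []
    for v, name in enumerate(feature_names):
        beg, end = ragged_index[v:v + 2]
        # Row-wise max over the block is nonzero iff that row observed the feature.
        count = np.count_nonzero(data[:, beg:end].max(1))
        if count == 0:
            empty.append(name)
    return empty


# Hypothetical toy data: 3 features with 2, 3 and 2 categories (7 columns).
ragged_index = np.array([0, 2, 5, 7])
data = np.array([
    [1, 0, 0, 0, 1, 0, 0],
    [0, 1, 0, 1, 0, 0, 0],
    [1, 0, 1, 0, 0, 0, 0],
], dtype=np.int8)
print(empty_features(data, ragged_index, ['color', 'size', 'shape']))  # prints ['shape']
```

One caveat applies to both this sketch and the original loop. If a feature had zero columns (`beg == end`), `.max(1)` would be reducing an empty slice, and NumPy raises a `ValueError` in that case. Neither version guards against it.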
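`import_data` takes several data files as a single comma-delimited string, loads each against one schema, and stacks the rows with `np.concatenate`. Because the string is split on `','`, no individual path may itself contain a comma. The sketch below reproduces that split-load-concatenate pattern with the standard `csv` module. The one-hot encoder `load_one`, the schema dict `SCHEMA`, and the helpers `ragged_index_for` and `load_many` are all hypothetical stand-ins. This is not treecat's `load_data`, whose actual encoding may differ.

```python
import csv
import os
import tempfile

import numpy as np

# Hypothetical schema: each feature maps to its ordered list of categories.
SCHEMA = {'color': ['red', 'blue'], 'size': ['S', 'M', 'L']}


def ragged_index_for(schema):
    """Column offsets: feature v occupies [offsets[v], offsets[v + 1])."""
    sizes = [len(cats) for cats in schema.values()]
    return np.concatenate([[0], np.cumsum(sizes)])


def load_one(path, schema, encoding='utf-8'):
    """One-hot encode a single csv file against a fixed schema (illustrative)."""
    offsets = ragged_index_for(schema)
    rows = []
    with open(path, newline='', encoding=encoding) as f:
        for record in csv.DictReader(f):
            row = np.zeros(offsets[-1], dtype=np.int8)
            for v, (name, cats) in enumerate(schema.items()):
                value = record.get(name) or ''
                if value in cats:  # missing or unknown values stay all-zero
                    row[offsets[v] + cats.index(value)] = 1
            rows.append(row)
    # reshape keeps a (0, n_columns) shape for files with no data rows.
    return np.array(rows, dtype=np.int8).reshape(-1, offsets[-1])


def load_many(paths_csv, schema, encoding='utf-8'):
    """Split a comma-delimited list of paths and stack the rows of every file."""
    data = np.concatenate([
        load_one(path, schema, encoding) for path in paths_csv.split(',')
    ])
    data.flags.writeable = False  # same read-only guard as import_data
    return data


# Usage: write two small csv files to a temporary directory, then load both.
with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i, text in enumerate(['color,size\nred,M\n', 'color,size\nblue,\n']):
        path = os.path.join(tmp, 'part{}.csv'.format(i))
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)
        paths.append(path)
    data = load_many(','.join(paths), SCHEMA)
print(data.shape)  # prints (2, 5)
```

Setting `data.flags.writeable = False` after concatenation makes any later in-place write raise an error. This protects the imported array once it is shared with downstream objects such as `Table`.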