def import_data(data_csvs_in,
                types_csv_in,
                values_csv_in,
                groups_csv_in,
                dataset_out,
                encoding='utf-8'):
    """Import a comma-delimited list of csv files into internal treecat format.

    Common encodings include: utf-8, cp1252.

    Args:
        data_csvs_in: A single string of data csv paths separated by commas.
            Empty entries (e.g. from a trailing comma) are ignored.
        types_csv_in: Path forwarded to ``load_schema`` as its first input.
        values_csv_in: Path forwarded to ``load_schema`` as its second input.
        groups_csv_in: Path forwarded to ``load_schema`` as its third input.
        dataset_out: Destination passed to ``pickle_dump`` for the resulting
            ``{'schema': ..., 'table': ...}`` dict.
        encoding: Text encoding forwarded to ``load_schema`` and ``load_data``.

    Raises:
        ValueError: If ``data_csvs_in`` names no data files.
    """
    # Drop empty entries so "a.csv," or "" do not reach load_data as ''.
    data_csvs = [path for path in data_csvs_in.split(',') if path]
    if not data_csvs:
        # np.concatenate([]) would otherwise fail with an opaque message.
        raise ValueError(
            'No data csv files given in {!r}'.format(data_csvs_in))

    schema = load_schema(types_csv_in, values_csv_in, groups_csv_in, encoding)

    # Stack the rows of every data file into one array.
    data = np.concatenate([
        load_data(schema, data_csv_in, encoding) for data_csv_in in data_csvs
    ])
    # Mark the array read-only so later consumers cannot mutate it in place.
    data.flags.writeable = False
    print('Imported data shape: [{}, {}]'.format(data.shape[0], data.shape[1]))

    # Feature v owns the column slice ragged_index[v]:ragged_index[v + 1].
    ragged_index = schema['ragged_index']
    for v, name in enumerate(schema['feature_names']):
        beg, end = ragged_index[v:v + 2]
        # Use .any() rather than .max(1): max raises on zero-size slices
        # (no rows, or a feature with no value columns). The warning fires
        # when every entry in the feature's slice is zero.
        if not data[:, beg:end].any():
            print('WARNING: No values found for feature {}'.format(name))

    # Every feature is imported with the multinomial type.
    feature_types = [TY_MULTINOMIAL] * len(schema['feature_names'])
    table = Table(feature_types, ragged_index, data)
    dataset = {
        'schema': schema,
        'table': table,
    }
    pickle_dump(dataset, dataset_out)
评论列表
文章目录