def put_in_buckets(data_array, labels, buckets, mode='pad'):
    """
    Group samples into length buckets and pad/truncate each group to its bucket edge.

    Each sample's length ``len(s)`` is located among ``buckets`` with
    ``np.digitize(..., right=False)``, which yields ``i`` such that
    ``buckets[i-1] <= length < buckets[i]``.

    * ``mode='pad'`` (default): the sample goes to bucket ``i``, so it is handed
      to ``pad_data`` with ``max_len=buckets[i]`` (an edge >= its length).
      Samples of length ``>= buckets[-1]`` go to the last bucket; ``pad_data``
      then receives ``max_len=buckets[-1]`` and presumably truncates them.
    * ``mode='truncate'``: the sample goes to bucket ``i - 1``, i.e. the largest
      edge not exceeding its length. Samples of length ``>= buckets[-1]`` go to
      the last bucket, and samples shorter than ``buckets[0]`` go to bucket 0.

    Any other ``mode`` value behaves like ``'pad'``.

    :param data_array: samples supporting ``len()`` per element and integer-array
        (fancy) indexing -- presumably a numpy array; confirm with callers.
    :param labels: per-sample labels aligned with ``data_array`` and indexable by
        an integer array.
    :param buckets: bucket edges; assumed sorted ascending and non-empty.
    :param mode: ``'pad'`` or ``'truncate'`` (see above).
    :return: ``(bucketed_data, reordering_indexes)``. Both are dicts keyed by
        zero-based bucket index, containing only non-empty buckets.
        ``bucketed_data[b]`` is whatever ``pad_data`` returns for bucket ``b``.
        ``reordering_indexes[b]`` holds the positions in ``data_array`` of the
        samples placed in bucket ``b``, in ascending order.
    """
    last_bucket = len(buckets) - 1
    input_lengths = np.array([len(s) for s in data_array], dtype='int')

    # Kept as a numpy array so the arithmetic and the elementwise comparison
    # below work (a Python list does not support `-= 1`).
    input_bucket_index = np.digitize(input_lengths, buckets, right=False)
    if mode == 'truncate':
        # Shift down one edge: cut each sample to the largest edge <= its length.
        input_bucket_index = input_bucket_index - 1
    # Clamp into valid positions. Shifting happens first, so over-long samples
    # still reach the last bucket in truncate mode. Samples shorter than the
    # first edge land in bucket 0 instead of wrapping round to buckets[-1].
    input_bucket_index = np.clip(input_bucket_index, 0, last_bucket)

    bucketed_data = {}
    reordering_indexes = {}
    # .tolist() yields plain Python ints, so dict keys count from zero as ints.
    for bucket in np.unique(input_bucket_index).tolist():
        length_indexes = np.where(input_bucket_index == bucket)[0]
        reordering_indexes[bucket] = length_indexes
        maxlen = int(np.floor(buckets[bucket]))
        bucketed_data[bucket] = pad_data(
            data_array[length_indexes], labels[length_indexes], max_len=maxlen
        )
    return bucketed_data, reordering_indexes
评论列表
文章目录