def buckets(x, y, size=50):
    """Group aligned input/target examples into length-based buckets.

    ``x`` and ``y`` are lists of parallel sequences: ``x[k][j]`` and
    ``y[k][j]`` both belong to example ``j``. Examples are sorted (stably)
    by ``len`` of their item from ``x[0]``. They are then split into
    consecutive buckets whose length limits are multiples of ``size``.
    A new bucket is opened only when an example exceeds the current limit
    *and* the current bucket already holds something. If the current bucket
    is empty, it is reused with the raised limit.

    Args:
        x: Non-empty list of input sequences, all with the same number of
            examples as every sequence in ``y``.
        y: Non-empty list of target sequences.
        size: Positive bucket width. Every bucket's length limit is a
            multiple of it.

    Returns:
        A tuple ``(x_buckets, y_buckets)``. ``x_buckets[k][b]`` is the list
        of items from ``x[k]`` that fell into bucket ``b``, in sorted order.
        ``y_buckets`` is laid out the same way for ``y``.

    Raises:
        ValueError: If ``x`` or ``y`` is empty, if the sequences do not all
            have the same number of examples, or if ``size`` is not positive.
    """
    if not x or not y:
        raise ValueError("x and y must each contain at least one sequence")
    if size <= 0:
        # A non-positive width would make the limit-raising loop below spin forever.
        raise ValueError("size must be positive")

    samples = list(x) + list(y)
    n_examples = len(samples[0])
    if any(len(seq) != n_examples for seq in samples):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError("all sequences in x and y must have the same length")
    num_inputs = len(x)

    # One tuple per example (one item per sequence), ordered by the length of
    # its first input item. sorted() is used because zip() is lazy on Python 3.
    examples = sorted(zip(*samples), key=lambda ex: len(ex[0]))

    limit = size
    # groups[k][b] collects the items of sequence k that land in bucket b;
    # every sequence always has the same number of buckets.
    groups = [[[]] for _ in samples]
    for ex in examples:
        length = len(ex[0])
        if length > limit:
            # Start a fresh bucket only when the current one is non-empty.
            if groups[0][-1]:
                for group in groups:
                    group.append([])
            # Raise the limit to the next multiple of size that fits this example.
            while length > limit:
                limit += size
        for group, item in zip(groups, ex):
            group[-1].append(item)
    return groups[:num_inputs], groups[num_inputs:]
评论列表
文章目录