def create_batches(data, batch_size, padding_id, label=True, sort=True, shuffle=True):
    """Group ``data`` into batches built by ``create_input``.

    Args:
        data: sequence of examples. Each example ``x`` must support
            ``len(x[0])`` (the sort key) and ``x[1]``. When ``label`` is true,
            ``x[1]`` is checked against -1; presumably -1 marks an unlabeled
            example (TODO confirm).
        batch_size: number of examples per batch. The last batch may be
            smaller.
        padding_id: forwarded unchanged to ``create_input``.
        label: when true, every example must satisfy ``x[1] != -1``. When
            false, the second element of every batch is zeroed in place after
            batching.
        sort: when true, examples are ordered by ``len(x[0])`` descending
            before batching, so neighbouring examples share a batch.
        shuffle: when true, the batch order is randomly permuted with
            ``np.random.permutation`` and ``data`` is reordered to match.

    Returns:
        ``(batches, data)``: the list of batches returned by ``create_input``,
        and the examples in the same order in which they appear across
        ``batches``. When ``shuffle`` is true, the returned ``data`` is a new
        list.
    """
    if label:
        for d in data:
            assert d[1] != -1
    if sort:
        data = sorted(data, key=lambda x: len(x[0]), reverse=True)

    # One batch per consecutive slice of ``batch_size`` examples.
    # ``range`` (rather than ``xrange``) keeps this valid on Python 2 and 3.
    starts = range(0, len(data), batch_size)
    # idxs, idys -- presumably what create_input returns per batch (TODO confirm)
    batches = [create_input(data[i:i + batch_size], padding_id) for i in starts]

    if shuffle:
        order = np.random.permutation(len(batches))
        new_batches = [batches[j] for j in order]
        # Reorder the examples batch-by-batch so ``data`` stays aligned with
        # ``batches``. The flat comprehension replaces
        # ``reduce(operator.add, ...)``, which was quadratic and raised
        # TypeError on empty input.
        new_data = [example
                    for j in order
                    for example in data[j * batch_size:(j + 1) * batch_size]]
        # Check before rebinding ``data``. Checking afterwards compared
        # new_data with itself and could never fail.
        assert len(new_data) == len(data)
        batches, data = new_batches, new_data

    if not label:
        # Unlabeled mode: zero the second element of every batch in place.
        # This assumes that element (presumably the label array) supports
        # slice assignment from a scalar, e.g. a numpy array -- TODO confirm.
        for b in batches:
            b[1][:] = 0
    return batches, data
评论列表
文章目录