def get_batches(data, batch_size, vocabulary, pos_vocabulary):
    """Shuffle ``data`` and split it into padded batches.

    Get batches without any restrictions on number of antecedents and
    negative candidates.

    Side effects: the module-level ``random`` generator is re-seeded with a
    fixed value and ``data`` is shuffled in place, so the caller's list is
    reordered (identically on every call for the same input order).

    Args:
        data: Mutable sequence of examples; shuffled in place, then sliced.
        batch_size: Maximum number of examples per batch (an integer; the
            last batch may be smaller).
        vocabulary: Passed through unchanged to ``pad_batch``.
        pos_vocabulary: Passed through unchanged to ``pad_batch``.

    Returns:
        A list with one entry per batch, each being whatever ``pad_batch``
        returns for that slice of ``data``. Empty when ``data`` is empty.
    """
    # Fixed seed keeps the shuffle order reproducible across runs.
    random.seed(24)
    random.shuffle(data)

    data_size = len(data)
    # Integer ceiling division: exact for any size, unlike the float-based
    # modulo/division round trip, and needs no special case for remainders.
    num_batches = -(-data_size // batch_size)

    # Slicing clamps the end index, so the final (possibly short) batch
    # needs no explicit min() against data_size.
    batches = [
        pad_batch(
            data[batch_num * batch_size:(batch_num + 1) * batch_size],
            vocabulary,
            pos_vocabulary,
        )
        for batch_num in range(num_batches)
    ]

    # Lazy %-args: formatting only happens if the record is actually emitted.
    logging.info('Data size: %s', data_size)
    logging.info('Number of batches: %s', len(batches))
    return batches
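# Comment list (web-page navigation text left over from the source listing; not code)
# Table of contents (web-page navigation text left over from the source listing; not code)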