def get_label_batch(label_data, batch_size, batch_index):
    """Return one batch of labels as int64, wrapping around the dataset end.

    The batch starts at offset ``batch_index * batch_size`` modulo the number
    of examples (the size of ``label_data`` along axis 0). When the batch runs
    past the last example it continues from the first one, so the result
    always has exactly ``batch_size`` entries along axis 0.

    Args:
        label_data: Array-like of labels, indexed by example along axis 0.
        batch_size: Number of examples per batch.
        batch_index: Zero-based index of the batch to fetch.

    Returns:
        A new ``np.int64`` array holding the selected labels.

    Raises:
        ZeroDivisionError: If ``label_data`` has no examples and the batch
            arguments are plain Python ints (the start offset is computed
            modulo the example count).
    """
    labels = np.asarray(label_data)
    nrof_examples = np.size(labels, 0)
    # Offset of the batch's first example, wrapped into the dataset.
    start = batch_index * batch_size % nrof_examples
    # Modular indices give exactly ``batch_size`` entries even when the batch
    # crosses the end of the data (or is larger than the dataset). The old
    # slice-and-vstack code returned the wrong number of rows in that case
    # and turned 1-D labels into a 2-row matrix.
    indices = np.arange(start, start + batch_size) % nrof_examples
    return labels[indices].astype(np.int64)
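# Page-navigation residue from the source article, not part of the code:
# "评论列表" (comment list), "文章目录" (table of contents).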