def get_batch(image_data, batch_size, batch_index):
    """Return the ``batch_index``-th batch of examples as ``float32``.

    Batches are taken along axis 0 of ``image_data``. The starting offset is
    ``batch_index * batch_size`` modulo the number of examples, and a batch
    that would run past the end of the data wraps around to the beginning, so
    the result always contains exactly ``batch_size`` examples.

    Args:
        image_data: NumPy array whose first axis indexes the examples
            (presumably a 4-D image tensor -- confirm against the caller).
        batch_size: Number of examples to return.
        batch_index: Zero-based index of the batch to fetch.

    Returns:
        A new ``np.float32`` array with ``batch_size`` entries along axis 0.

    Raises:
        ZeroDivisionError: If ``image_data`` contains no examples.
    """
    nrof_examples = np.size(image_data, 0)
    start = batch_index * batch_size % nrof_examples
    end = start + batch_size
    if end <= nrof_examples:
        # Contiguous case: a plain slice is enough.
        batch = image_data[start:end]
    else:
        # The batch crosses the end of the data. mode='wrap' maps indices
        # >= nrof_examples back to the start, so exactly batch_size examples
        # are returned even when batch_size exceeds the dataset size.
        batch = np.take(image_data, np.arange(start, end), axis=0, mode='wrap')
    return batch.astype(np.float32)
# (Web-page residue from the original post: "Comment list", "Table of contents")