def next_batch(self):
    """Return the next mini-batch of data and labels.

    Advances ``self.index_in_epoch`` by ``batch_size // pairs_per_img``
    entries of ``self.img_path_list``. When the advanced index would exceed
    ``self.number`` the window restarts from index 0. Every selected path is
    passed to ``generate_data`` and the per-image results are concatenated
    along axis 0.

    ``batch_size``, ``pairs_per_img`` and ``generate_data`` are resolved from
    the enclosing module scope.

    Returns:
        tuple: ``(data_batch, label_batch)``. ``data_batch`` has its axes
        reordered with ``(0, 2, 3, 1)``; ``label_batch`` has all length-1
        axes removed.
    """
    self.count += 1

    # Floor division keeps the index an int on Python 3 as well; with true
    # division the index becomes a float and both range() and list indexing
    # raise TypeError.
    imgs_per_batch = batch_size // pairs_per_img

    start = self.index_in_epoch
    self.index_in_epoch += imgs_per_batch
    if self.index_in_epoch > self.number:
        # Not enough images left for a full batch: restart from the beginning.
        start = 0
        self.index_in_epoch = imgs_per_batch
    end = self.index_in_epoch

    # The image at `start` is always loaded, even if the window is empty,
    # matching the original behaviour.
    samples = [generate_data(self.img_path_list[start])]
    samples.extend(
        generate_data(self.img_path_list[i]) for i in range(start + 1, end)
    )

    # Concatenate once instead of growing the arrays inside the loop, which
    # recopied the whole batch on every iteration.
    # Per the original author's notes each sample is ([4, 2, 128, 128],
    # [4, 1, 8]) and a batch is 64 pairs -- TODO confirm against generate_data.
    data_batch = np.concatenate([data for data, _ in samples])
    label_batch = np.concatenate([label for _, label in samples])

    # Move axis 1 to the end, e.g. (64, 2, 128, 128) -> (64, 128, 128, 2).
    data_batch = data_batch.transpose(0, 2, 3, 1)
    # Drop singleton axes, e.g. (64, 1, 8) -> (64, 8).
    label_batch = label_batch.squeeze()
    return data_batch, label_batch
评论列表
文章目录