def get_next_batch(self, mode, idx):
    """Return the ``idx``-th batch of data samples for the given split.

    Args:
        mode: ``"train"`` or ``"valid"``; any other value selects the test
            split.
        idx: Position of the batch to fetch. In train mode it is mapped
            through ``self.train_idx`` to get the batch number (presumably a
            shuffled batch order -- confirm against the caller); otherwise it
            is the batch number itself.

    Returns:
        Whatever ``self.next_batch_feed_dict_by_dataset`` returns when called
        with the selected dataset, a NumPy index expression selecting the
        batch rows, and the number of rows that expression selects.
    """
    batch_size = self.args.batch_size
    if mode == "train":
        dataset = self.train_data
        sample_num = self.train_sample_num
    elif mode == "valid":
        dataset = self.valid_data
        sample_num = self.valid_sample_num
    else:
        dataset = self.test_data
        sample_num = self.test_sample_num

    # Row count is taken from the first field of the dataset (assumes every
    # field has the same length -- confirm).
    n_rows = len(dataset[0])

    batch_no = self.train_idx[idx] if mode == "train" else idx
    start = batch_no * batch_size
    stop = (batch_no + 1) * batch_size
    if mode != "train" and stop >= sample_num:
        # Final evaluation batch: take every remaining row. ``None`` is the
        # open-ended slice bound; the previous ``-1`` sentinel excluded the
        # last row of the data.
        stop = None

    # Report the number of rows the slice really selects so a short batch is
    # never reported as a full one.
    end = n_rows if stop is None else min(stop, n_rows)
    samples = max(0, end - start)

    _slice = np.index_exp[start:stop]
    return self.next_batch_feed_dict_by_dataset(dataset, _slice, samples)
评论列表
文章目录