def batch_loader(self, rnd_gen=np.random, shuffle=True):
    """Yield the dataset as a sequence of fixed-size minibatches.

    Each yielded value is a dict with keys ``"X"``, ``"y"`` and ``"ID"``
    holding the inputs, targets and sample identifiers of one minibatch.
    The arrays always have shapes ``self.X_shape`` / ``self.y_shape`` /
    ``(self.batchsize,)`` (X and y as float32). When ``self.n_samples`` is
    not a multiple of ``self.batchsize``, the final batch is padded after
    its last real sample with zeros (``X``, ``y``) and ``None`` (``ID``).
    NOTE(review): presumably ``X_shape[0] == y_shape[0] == batchsize`` --
    confirm with the class that sets these attributes.

    Args:
        rnd_gen: Object providing ``shuffle`` (e.g. the ``np.random`` module
            or a ``np.random.RandomState``), used to permute sample order.
        shuffle: If true, visit samples in random order; otherwise in index
            order.

    Yields:
        dict: ``{"X": ..., "y": ..., "ID": ...}`` for one minibatch.

    Note:
        The same three output arrays are refilled for every batch, which
        avoids a per-batch allocation. Copy them if a batch must remain
        valid after the generator is advanced.
    """
    batchsize = self.batchsize
    n_samples = self.n_samples
    inds = np.arange(n_samples)
    if shuffle:
        rnd_gen.shuffle(inds)
    # Output buffers: allocated once, overwritten for every batch.
    x = np.zeros(self.X_shape, np.float32)
    y = np.zeros(self.y_shape, np.float32)
    ids = np.empty((batchsize,), np.object_)
    # Integer stepping replaces the removed ``np.int(np.ceil(...))`` count.
    for start in range(0, n_samples, batchsize):
        batch_inds = inds[start:start + batchsize]
        n = len(batch_inds)
        x[:n] = self.x[batch_inds]
        y[:n] = self.y[batch_inds]
        ids[:n] = self.ids[batch_inds]
        # Clear the unused tail of a short final batch. Without this, the
        # tail would still hold samples left over from the previous batch.
        x[n:] = 0
        y[n:] = 0
        ids[n:] = None
        yield dict(X=x, y=y, ID=ids)
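# (web-page residue, not code) 评论列表 -- "Comment list"
# (web-page residue, not code) 文章目录 -- "Table of contents"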