def batch(self):
    """Yield batches of samples drawn uniformly at random from the database.

    This is an infinite generator. On every iteration it draws
    ``self.batch_size`` indices uniformly (with replacement) from
    ``[0, self.db.nb_samples)`` using ``self.rng``. It fetches each index
    with ``self.db.get_sample`` and stacks the values stored under
    ``self.keys`` into newly allocated arrays.

    Yields
    ------
    numpy.ndarray or (numpy.ndarray, ...)
        One array per entry of ``self.keys``, in the same order as ``keys``.
        Each array has shape ``(self.batch_size,) + tuple(self.spec[key]['shape'])``
        and dtype ``self.spec[key]['dtype']``. When ``self.keys`` contains a
        single key, that array is yielded on its own rather than inside a
        1-tuple.
    """
    keys = self.keys
    # The (shape, dtype) of one full batch per key does not change between
    # iterations, so compute it once. tuple() also accepts list-valued shapes.
    layouts = [((self.batch_size,) + tuple(self.spec[key]['shape']),
                self.spec[key]['dtype'])
               for key in keys]
    while True:
        # Allocate new buffers for every batch. Reusing one set of arrays
        # would silently overwrite batches the consumer still holds
        # (e.g. ones queued by a prefetching data loader).
        data = [np.zeros(shape, dtype=dtype) for shape, dtype in layouts]
        # Sample indices uniformly (with replacement).
        batch_idxs = self.rng.randint(self.db.nb_samples,
                                      size=self.batch_size,
                                      dtype=np.uint64)
        for i, idx in enumerate(batch_idxs):
            sample = self.db.get_sample(idx)
            for buf, key in zip(data, keys):
                buf[i] = sample[key]
        # A single-key batch is yielded as a bare array, not a 1-tuple.
        if len(data) == 1:
            yield data[0]
        else:
            yield tuple(data)
评论列表
文章目录