def binned_batch_stream(target_statistics, batch_size, n_batches, n_bins=64,
                        rng=None):
    """Yield histogram-stratified batches of indices with importance weights.

    The values in ``target_statistics`` are histogrammed into ``n_bins``
    equal-width bins (``np.histogram``). Every batch draws the same number of
    indices, ``batch_size // n_bins``, from each non-empty bin, sampling with
    replacement. The accompanying weights compensate for that stratification:
    a sample drawn from bin ``k`` gets weight ``hist[k] / per_category``, so
    the weights of one batch sum to ``len(target_statistics)``.

    Parameters
    ----------
    target_statistics : array_like
        One scalar per example; the yielded indices are positions into this
        array.
    batch_size : int
        Requested batch size. The actual batch length is
        ``(batch_size // n_bins) * <number of non-empty bins>``, which is
        smaller than ``batch_size`` when it is not a multiple of ``n_bins``
        or when some bins are empty.
    n_batches : int
        Number of batches to yield.
    n_bins : int, optional
        Number of histogram bins (default 64).
    rng : optional
        Object providing ``choice(a, size=..., replace=...)``, such as
        ``np.random.RandomState`` or ``np.random.Generator``. Defaults to the
        global ``np.random`` module; pass a seeded instance for reproducible
        batches.

    Yields
    ------
    (indices, weights) : tuple of ndarray
        ``indices`` are integer positions into ``target_statistics``, grouped
        bin by bin in ascending bin order; ``weights`` is a float32 array of
        the same length. The same ``weights`` array object is yielded for
        every batch, so do not modify it in place.

    Raises
    ------
    ValueError
        If ``target_statistics`` is empty or ``batch_size < n_bins`` (no bin
        could receive a sample). Because this is a generator, the error is
        raised on the first ``next()``.
    """
    target_statistics = np.asarray(target_statistics)
    if target_statistics.size == 0:
        raise ValueError("target_statistics must not be empty")

    # Floor division: this count is used as a sample size and a repeat count,
    # both of which must be integers (true division breaks on Python 3).
    per_category = int(batch_size // n_bins)
    if per_category < 1:
        raise ValueError(
            "batch_size (%d) must be at least n_bins (%d)" % (batch_size, n_bins))

    hist, _ = np.histogram(target_statistics, bins=n_bins)
    # Sorting makes each bin a contiguous run of indices, so the sorted order
    # can be cut at the cumulative bin counts.
    order = np.argsort(target_statistics)
    bin_indices = np.array_split(order, np.cumsum(hist)[:-1])

    # Empty bins cannot be sampled from (choice() rejects an empty
    # population), so only non-empty bins contribute samples and weights.
    nonempty = hist > 0
    populated = [ind for ind, keep in zip(bin_indices, nonempty) if keep]
    weight_correction = (
        hist[nonempty].astype(np.float64) / per_category).astype(np.float32)
    weights = np.repeat(weight_correction, per_category)

    sampler = np.random if rng is None else rng
    for _ in range(n_batches):
        sample = [
            sampler.choice(ind, size=per_category, replace=True)
            for ind in populated
        ]
        yield np.hstack(sample), weights
评论列表
文章目录