batch_streams.py source code

python

Project: crayimage   Author: yandexdataschool
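import numpy as np

def binned_batch_stream(target_statistics, batch_size, n_batches, n_bins=64):
  """Yield (indices, weights) batches that sample every histogram bin of
  `target_statistics` equally often; the weights undo that resampling bias."""
  # Histogram the statistic and sort the sample indices by its value.
  hist, bins = np.histogram(target_statistics, bins=n_bins)
  indx = np.argsort(target_statistics)
  # Split the sorted indices into one group per histogram bin.
  indices_categories = np.array_split(indx, np.cumsum(hist)[:-1])
  n_samples = target_statistics.shape[0]

  # `//` keeps the Python 2 integer-division behaviour the original relied on.
  per_category = batch_size // n_bins

  # Importance weight of a sample from bin k: (hist[k] / n_samples) / (1 / n_bins).
  weight_correction = (n_bins * np.float64(hist) / n_samples).astype('float32')
  wc = np.repeat(weight_correction, per_category)

  for _ in range(n_batches):
    # Draw per_category indices (with replacement) from every bin.
    # Note: np.random.choice raises ValueError if a bin is empty.
    sample = [
      np.random.choice(ind, size=per_category, replace=True)
      for ind in indices_categories
    ]

    yield np.hstack(sample), wc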
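binned_batch_stream passes every bin, including empty ones, to np.random.choice, which raises ValueError for an empty bin; with skewed statistics and 64 bins that happens easily. The sketch below is a hypothetical variant, not part of crayimage: the name binned_batch_stream_nonempty and the rng parameter are assumptions. It samples uniformly over the m non-empty bins only. Bin k carries probability hist[k] / n_samples under the data but 1 / m under the sampler, so its importance weight is m * hist[k] / n_samples, which reduces to the original n_bins * hist[k] / n_samples when no bin is empty.

```python
import numpy as np


def binned_batch_stream_nonempty(target_statistics, batch_size, n_batches,
                                 n_bins=64, rng=None):
  """Hypothetical variant of binned_batch_stream that skips empty bins.

  Samples are drawn uniformly over the m non-empty bins, so a sample
  from bin k gets importance weight m * hist[k] / n_samples.
  """
  rng = np.random.default_rng() if rng is None else rng
  hist, _ = np.histogram(target_statistics, bins=n_bins)
  order = np.argsort(target_statistics)
  # One array of sample indices per bin, in bin order (some may be empty).
  categories = np.array_split(order, np.cumsum(hist)[:-1])

  # Keep only the bins that actually contain samples.
  nonempty = hist > 0
  categories = [c for c, keep in zip(categories, nonempty) if keep]
  n_nonempty = len(categories)
  n_samples = target_statistics.shape[0]

  per_category = batch_size // n_nonempty
  weights = (n_nonempty * hist[nonempty] / n_samples).astype('float32')
  # The weights do not depend on the draw, so compute them once.
  wc = np.repeat(weights, per_category)

  for _ in range(n_batches):
    sample = [rng.choice(c, size=per_category, replace=True)
              for c in categories]
    yield np.hstack(sample), wc


# Usage: 61 of the 64 bins are empty here, so the original function would fail.
stats = np.concatenate([np.zeros(500), np.ones(500), [10.0]])
batches = binned_batch_stream_nonempty(stats, batch_size=30, n_batches=2,
                                       rng=np.random.default_rng(0))
for idx, w in batches:
  # The weighted batch mean equals the full-data mean because each bin holds a single value.
  print(idx.shape, np.average(stats[idx], weights=w), stats.mean())
```

Since every non-empty bin contributes per_category draws, the weights in a batch always average to 1, so the weighted batch statistics stay on the scale of the original data.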