datasets.py source code

python

Project: tensorflow-layer-library  Author: bioinf-jku
import numpy as np  # module-level import required by the method below


def batch_loader(self, rnd_gen=np.random, shuffle=True):
        """batch_loader yields a new minibatch at each iteration"""
        batchsize = self.batchsize
        inds = np.arange(self.n_samples)
        if shuffle:
            rnd_gen.shuffle(inds)
        # np.int was removed in NumPy 1.24; the builtin int does the same job here
        n_mbs = int(np.ceil(self.n_samples / batchsize))

        # Pre-allocated buffers; they are overwritten in place for every minibatch
        x = np.zeros(self.X_shape, np.float32)
        y = np.zeros(self.y_shape, np.float32)
        ids = np.empty((batchsize,), np.object_)

        for m in range(n_mbs):
            start = m * batchsize
            end = (m + 1) * batchsize
            if end > self.n_samples:
                end = self.n_samples
            mb_slice = slice(start, end)

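            # Only the first (end - start) rows are refreshed; in a shorter final
            # minibatch the remaining rows still hold data from the previous one
            x[:end - start, :] = self.x[inds[mb_slice], :]
            y[:end - start, :] = self.y[inds[mb_slice], :]
            ids[:end - start] = self.ids[inds[mb_slice]]

            # The same buffer objects are yielded each time; copy them if a
            # minibatch must outlive the next iteration
            yield dict(X=x, y=y, ID=ids)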
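
The following is a minimal usage sketch, not part of the original project. It assumes a hypothetical `ToyDataset` class that supplies the attributes the method reads (`n_samples`, `batchsize`, `x`, `y`, `ids`, `X_shape`, `y_shape`); the real class in tensorflow-layer-library may set these up differently. The method body is copied from above with `np.int` swapped for `int`. The sketch illustrates two consequences of the pre-allocated buffers: every yielded dict points at the same arrays, and a shorter final minibatch keeps leftover rows from the previous one.

```python
import numpy as np


class ToyDataset:
    """Hypothetical stand-in for the class that owns batch_loader."""

    def __init__(self, n_samples=5, n_features=3, batchsize=2):
        self.n_samples = n_samples
        self.batchsize = batchsize
        # Row i of x is [3*i, 3*i + 1, 3*i + 2], so rows are easy to recognise
        self.x = np.arange(n_samples * n_features, dtype=np.float32).reshape(
            n_samples, n_features)
        self.y = np.arange(n_samples, dtype=np.float32).reshape(n_samples, 1)
        self.ids = np.array(["sample_%d" % i for i in range(n_samples)],
                            dtype=np.object_)
        # Assumption: the buffer shapes are (batchsize, ...) as the method implies
        self.X_shape = (batchsize, n_features)
        self.y_shape = (batchsize, 1)

    def batch_loader(self, rnd_gen=np.random, shuffle=True):
        """Yield a new minibatch at each iteration (same logic as above)."""
        batchsize = self.batchsize
        inds = np.arange(self.n_samples)
        if shuffle:
            rnd_gen.shuffle(inds)
        n_mbs = int(np.ceil(self.n_samples / batchsize))

        x = np.zeros(self.X_shape, np.float32)
        y = np.zeros(self.y_shape, np.float32)
        ids = np.empty((batchsize,), np.object_)

        for m in range(n_mbs):
            start = m * batchsize
            end = min((m + 1) * batchsize, self.n_samples)
            mb_slice = slice(start, end)

            x[:end - start, :] = self.x[inds[mb_slice], :]
            y[:end - start, :] = self.y[inds[mb_slice], :]
            ids[:end - start] = self.ids[inds[mb_slice]]

            yield dict(X=x, y=y, ID=ids)


dataset = ToyDataset()

# Copy each minibatch while iterating, because the buffers are reused
snapshots = []
id_snapshots = []
for mb in dataset.batch_loader(shuffle=False):
    snapshots.append(mb["X"].copy())
    id_snapshots.append(mb["ID"].copy())

# 5 samples with batchsize 2 give 3 minibatches; the last has one fresh row
print(len(snapshots))
print(snapshots[2])     # second row is left over from the previous minibatch
print(id_snapshots[2])

# Without copying, every collected dict refers to the same arrays
batches = list(dataset.batch_loader(shuffle=False))
print(batches[0]["X"] is batches[2]["X"])
```

For reproducible shuffling, a seeded generator such as `np.random.RandomState(0)` can be passed as `rnd_gen`, since the method only calls its `shuffle` method.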