utils.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:chinese_generation 作者: polaroidz 项目源码 文件源码
def batch_generator(batch_size, nb_batches):
    batch_count = 0

    while True:
        pos = batch_count * batch_size
        batch = dataset[pos:pos+batch_size]

        X = np.zeros((batch_size, 1, img_size, img_size), dtype=np.float32)

        for k, path in enumerate(batch):
            im = io.imread(path)
            im = color.rgb2gray(im)

            X[k] = im[np.newaxis, ...]

        X = torch.from_numpy(X)
        X = Variable(X)

        yield X, batch

        batch_count += 1

        if batch_count > nb_batches:
            batch_count = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号