WGAN_mnist.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Keras-GAN 作者: Shaofanl 项目源码 文件源码
def get_mnist(nbatch=128):
    mnist = fetch_mldata('MNIST original', data_home='/home/shaofan/.sklearn/') 
    x, y = mnist.data, mnist.target
    x = x.reshape(-1, 1, 28, 28)

    ind = np.random.permutation(x.shape[0])
    x = x[ind]
    y = y[ind]

    def random_stream():
        while 1:
            yield x[np.random.choice(x.shape[0], replace=False, size=nbatch)].transpose(0, 2, 3, 1)
    return x, y, random_stream
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号