wikiartGenre.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:GANGogh 作者: rkjones4 项目源码 文件源码
def make_generator(files, batch_size, n_classes):
    """Build a class-balanced, endlessly repeating batch generator.

    Returns a zero-argument function ``get_epoch``. Calling it gives a
    generator that yields ``(images, labels)`` batches forever. For every
    style in the module-level ``styles`` mapping, each batch takes
    ``batch_size // n_classes`` consecutive images, read from
    ``"{path}/{style}/{index}.png"``. The rows of the batch are then shuffled
    while each image stays paired with its one-hot label.

    The function relies on these module-level names defined elsewhere in the
    file:
      * ``styles``: iterated for style keys. ``styles[style]`` is the index at
        which reading wraps back to 0 (presumably the image count per style;
        confirm).
      * ``styleNum``: maps a style to its label column index.
      * ``curPos``: maps a style to the next image index. It is read AND
        updated here, so reading resumes across batches.
      * ``path``: root directory of the per-style image folders.
      * ``DIM``: image height/width used for the batch buffer.

    Args:
        files: Not used to choose which images are read (images are read by
            index). Kept for interface compatibility. The original code only
            shuffled a temporary copy of ``files[style]``, which had no effect.
        batch_size: Number of rows per batch. Must be divisible by
            ``n_classes``.
        n_classes: Width of the one-hot label vectors.

    Returns:
        The ``get_epoch`` generator function.

    Raises:
        ValueError: If ``batch_size`` is not divisible by ``n_classes``.
    """
    if batch_size % n_classes != 0:
        raise ValueError("batch size must be divisible by num classes")

    class_batch = batch_size // n_classes

    def get_epoch():
        """Yield shuffled ``(images, labels)`` batches forever."""
        while True:
            # Channels-first buffer: (batch, 3, DIM, DIM).
            images = np.zeros((batch_size, 3, DIM, DIM), dtype='int32')
            labels = np.zeros((batch_size, n_classes))
            n = 0
            for style in styles:
                label_idx = int(styleNum[style])
                curr = curPos[style]
                for _ in range(class_batch):
                    # Wrap back to the first image once the end is reached.
                    if curr == styles[style]:
                        curr = 0
                    # NOTE: scipy.misc.imread was removed in SciPy 1.2; this
                    # needs an older SciPy (or a replacement image reader).
                    image = scipy.misc.imread(
                        "{}/{}/{}.png".format(path, style, str(curr)),
                        mode='RGB')
                    row = n % batch_size
                    images[row] = image.transpose(2, 0, 1)  # HWC -> CHW
                    # Clear the row first so it stays one-hot even if the
                    # index wraps (only when len(styles) * class_batch
                    # exceeds batch_size).
                    labels[row] = 0
                    labels[row, label_idx] = 1
                    n += 1
                    curr += 1
                curPos[style] = curr

            # One shared permutation shuffles the rows of both arrays, so
            # each image keeps its conditioning label.
            perm = np.random.permutation(batch_size)
            yield (images[perm], labels[perm])

    return get_epoch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号