def make_generator(files, batch_size, n_classes):
    """Build a factory for class-balanced (images, labels) batches.

    Each batch takes ``batch_size // n_classes`` images from every style in
    the module-level ``styles`` mapping and one-hot labels them with the
    class index from ``styleNum``. Images and labels are then shuffled
    together before the batch is yielded.

    The function relies on names defined elsewhere in the file:

    - ``np``, ``random``, ``scipy``: modules.
    - ``DIM``: image side length.
    - ``path``: root directory of the images.
    - ``styles``: style -> number of images available for that style.
    - ``styleNum``: style -> class index.
    - ``curPos``: style -> read cursor. It is updated in place, so reading
      progress persists across batches.

    Args:
        files: Not consulted when reading images. Images are loaded by index
            from ``"{path}/{style}/{index}.png"``. Kept so existing callers
            keep working. NOTE(review): the original only shuffled a
            throwaway copy of ``files[style]``; confirm what callers pass
            here.
        batch_size: Number of images per batch. Must be divisible by
            ``n_classes``.
        n_classes: Number of classes, i.e. the width of each one-hot label
            row.

    Returns:
        A zero-argument generator function. Each call returns an infinite
        generator of ``(images, labels)`` tuples:

        - ``images`` is an int32 array of shape ``(batch_size, 3, DIM, DIM)``.
        - ``labels`` has shape ``(batch_size, n_classes)``.

    Raises:
        ValueError: If ``batch_size`` is not divisible by ``n_classes``.
    """
    if batch_size % n_classes != 0:
        raise ValueError("batch size must be divisible by num classes")
    class_batch = batch_size // n_classes

    # Per-style read order, shared by every generator this factory returns.
    # It starts as 0..count-1 and is reshuffled in place whenever a style's
    # cursor wraps. The original called random.shuffle on a temporary list
    # copy, so that reshuffle never took effect.
    read_order = {}

    def get_epoch():
        while True:
            images = np.zeros((batch_size, 3, DIM, DIM), dtype='int32')
            labels = np.zeros((batch_size, n_classes))
            n = 0
            for style in styles:
                style_label = int(styleNum[style])
                count = styles[style]
                order = read_order.get(style)
                if order is None or len(order) != count:
                    order = read_order[style] = list(range(count))
                curr = curPos[style]
                for _ in range(class_batch):
                    if curr >= count:
                        # Finished a pass over this style: restart and
                        # randomize the order for the next pass.
                        curr = 0
                        random.shuffle(order)
                    # NOTE: scipy.misc.imread was removed in SciPy 1.2; this
                    # line needs an older SciPy or a replacement reader.
                    image = scipy.misc.imread(
                        "{}/{}/{}.png".format(path, style, str(order[curr])),
                        mode='RGB')
                    # Files must already be DIM x DIM (the resize step is not
                    # done here). The transpose moves the channel axis first
                    # to match the (3, DIM, DIM) slots in `images`.
                    # The modulo wraps if `styles` has more entries than
                    # n_classes; later rows then overwrite earlier ones.
                    images[n % batch_size] = image.transpose(2, 0, 1)
                    labels[n % batch_size, style_label] = 1
                    n += 1
                    curr += 1
                curPos[style] = curr
            # Shuffle images and labels with the same RNG state so each
            # image keeps its label.
            rng_state = np.random.get_state()
            np.random.shuffle(images)
            np.random.set_state(rng_state)
            np.random.shuffle(labels)
            yield (images, labels)

    return get_epoch
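# Leftover web-page navigation labels ("Comment list", "Table of contents")
# from where this snippet was copied; not part of the program.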