def train(gen1, gen2, dis, optimizer_gen, optimizer_dis, images, epoch_num, output_path, lr_decay=10, save_epoch=1, batch_size=64, margin=20, out_image_dir=None, clip_rect=None):
    """Run the training loop, reporting losses and saving snapshots per epoch.

    Each dataset item is decoded from encoded image bytes, optionally cropped,
    resized to ``image_size`` x ``image_size`` and scaled to [-1, 1] before the
    whole batch is handed to ``update``. ``latent_size``, ``image_size`` and
    ``update`` are module-level names defined outside this function.

    Args:
        gen1: First generator stage; also supplies the array module via
            ``gen1.xp``. Samples are produced as ``gen2(gen1(z))``.
        gen2: Second generator stage; this is the link saved as ``.gen.model``.
        dis: Discriminator link; forwarded to ``update`` and saved as
            ``.dis.model``.
        optimizer_gen: Generator optimizer; its ``alpha`` is halved on decay
            epochs and its state is saved as ``.gen.state``.
        optimizer_dis: Discriminator optimizer; handled like ``optimizer_gen``
            and saved as ``.dis.state``.
        images: Dataset of encoded image byte strings readable by PIL.
        epoch_num: Training stops once this many epochs have completed.
        output_path: Filename prefix for the saved ``.model``/``.state`` files.
        lr_decay: Halve both optimizers' ``alpha`` every ``lr_decay`` epochs.
        save_epoch: Write samples and snapshots every ``save_epoch`` epochs.
        batch_size: Number of images per training batch.
        margin: Forwarded unchanged to ``update``.
        out_image_dir: Directory for the per-epoch sample grid PNG, or ``None``
            to skip writing sample images.
        clip_rect: Optional 4-element crop box passed to ``Image.crop``
            (left, upper, right, lower). The left/upper corner is jittered by
            a random offset in [-4, 4] pixels; the right/lower corner is not.

    Returns:
        None.
    """
    xp = gen1.xp
    grid_rows = 10
    grid_cols = 10
    grid_size = grid_rows * grid_cols

    # Fixed latent vectors (normalized to unit L2 norm) so that the sample
    # grids written at different epochs are directly comparable.
    z_out_image = xp.random.normal(0, 1, (grid_size, latent_size)).astype(np.float32)
    z_out_image = z_out_image / (xp.linalg.norm(z_out_image, axis=1, keepdims=True) + 1e-12)

    # Host-side batch buffer, reused for every iteration.
    x_batch = np.zeros((batch_size, 3, image_size, image_size), dtype=np.float32)
    iterator = chainer.iterators.SerialIterator(images, batch_size)

    # time.clock() was removed in Python 3.8. Prefer perf_counter and only
    # fall back to clock on interpreters that predate perf_counter.
    timer = getattr(time, 'perf_counter', None) or time.clock

    sum_loss_gen = 0
    sum_loss_dis = 0
    num_loss = 0
    last_clock = timer()
    for batch_images in iterator:
        for j, image in enumerate(batch_images):
            with io.BytesIO(image) as b:
                # convert() forces the pixel data to be read, so the buffer
                # may be closed as soon as this statement finishes.
                pixels = Image.open(b).convert('RGB')
            if clip_rect is not None:
                offset_left = np.random.randint(-4, 5)
                offset_top = np.random.randint(-4, 5)
                # tuple() lets clip_rect be any sequence (list, tuple, array)
                # instead of failing or mis-adding on non-tuple input.
                crop_box = (clip_rect[0] + offset_left, clip_rect[1] + offset_top) + tuple(clip_rect[2:])
                pixels = pixels.crop(crop_box)
            pixels = np.asarray(pixels.resize((image_size, image_size)), dtype=np.float32)
            # HWC -> CHW, then map [0, 255] to [-1, 1].
            x_batch[j, ...] = pixels.transpose((2, 0, 1)) / 127.5 - 1

        loss_gen, loss_dis = update(gen1, gen2, dis, optimizer_gen, optimizer_dis, x_batch, margin)
        sum_loss_gen += loss_gen
        sum_loss_dis += loss_dis
        num_loss += 1

        if not iterator.is_new_epoch:
            continue

        # ---- end-of-epoch bookkeeping ----
        epoch = iterator.epoch
        current_clock = timer()
        print('epoch {} done {}s elapsed'.format(epoch, current_clock - last_clock))
        print('gen loss: {}'.format(sum_loss_gen / num_loss))
        print('dis loss: {}'.format(sum_loss_dis / num_loss))
        last_clock = current_clock
        sum_loss_gen = 0
        sum_loss_dis = 0
        num_loss = 0

        if epoch % lr_decay == 0:
            optimizer_gen.alpha *= 0.5
            optimizer_dis.alpha *= 0.5

        if epoch % save_epoch == 0:
            if out_image_dir is not None:
                image = np.zeros((grid_size, 3, image_size, image_size), dtype=np.uint8)
                # Generate one grid row at a time to bound device memory use.
                for i in six.moves.range(grid_rows):
                    begin_index = i * grid_cols
                    end_index = (i + 1) * grid_cols
                    with chainer.no_backprop_mode():
                        sub_image = gen2(gen1(z_out_image[begin_index:end_index], train=False), train=False).data
                    # Map [-1, 1] back to [0, 255] on the host.
                    sub_image = (cuda.to_cpu(sub_image) + 1) * 127.5
                    image[begin_index:end_index, ...] = sub_image.clip(0, 255).astype(np.uint8)
                # (rows, cols, C, H, W) -> (rows, H, cols, W, C) -> one tiled
                # (rows*H, cols*W, C) image.
                image = image.reshape(grid_rows, grid_cols, 3, image_size, image_size)
                image = image.transpose((0, 3, 1, 4, 2))
                image = image.reshape((grid_rows * image_size, grid_cols * image_size, 3))
                Image.fromarray(image).save(os.path.join(out_image_dir, '{0:04d}.png'.format(epoch)))
            # NOTE(review): only gen2 is serialized as the generator model;
            # confirm gen1's parameters are covered by it or saved elsewhere.
            serializers.save_npz('{0}_{1:03d}.gen.model'.format(output_path, epoch), gen2)
            serializers.save_npz('{0}_{1:03d}.gen.state'.format(output_path, epoch), optimizer_gen)
            serializers.save_npz('{0}_{1:03d}.dis.model'.format(output_path, epoch), dis)
            serializers.save_npz('{0}_{1:03d}.dis.state'.format(output_path, epoch), optimizer_dis)

        if epoch >= epoch_num:
            break
# Comment list
# Table of contents