train2.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:chainer-stack-gan 作者: dsanno 项目源码 文件源码
def _load_image(image_bytes, clip_rect=None):
    """Decode one encoded image into a normalized CHW float32 array.

    The image is converted to RGB, optionally cropped, resized to
    ``image_size`` x ``image_size`` (a setting defined elsewhere in this
    file) and scaled from [0, 255] to [-1, 1].

    Args:
        image_bytes: Encoded image data readable by ``PIL.Image.open``.
        clip_rect: Optional PIL crop box ``(left, top, right, bottom)``; a
            list is accepted as well as a tuple. As augmentation, the left
            and top edges are shifted by independent random offsets in
            [-4, 4]; the right and bottom edges stay fixed.

    Returns:
        ``np.float32`` array of shape ``(3, image_size, image_size)``.
    """
    with io.BytesIO(image_bytes) as buf:
        pixels = Image.open(buf).convert('RGB')
        if clip_rect is not None:
            # Draw left before top to keep the original RNG call order.
            offset_left = np.random.randint(-4, 5)
            offset_top = np.random.randint(-4, 5)
            # tuple() so that a list-typed clip_rect does not raise
            # "can only concatenate tuple (not list) to tuple".
            box = (clip_rect[0] + offset_left, clip_rect[1] + offset_top) + tuple(clip_rect[2:])
            pixels = pixels.crop(box)
        pixels = np.asarray(pixels.resize((image_size, image_size)), dtype=np.float32)
    # HWC -> CHW, then map [0, 255] to [-1, 1].
    return pixels.transpose((2, 0, 1)) / 127.5 - 1


def _save_sample_grid(gen1, gen2, z, row_num, col_num, path):
    """Render ``row_num`` x ``col_num`` samples of ``gen2(gen1(z))`` to a PNG.

    Args:
        gen1: First generator; called as ``gen1(z_rows, train=False)``.
        gen2: Second generator; called on ``gen1``'s output with ``train=False``.
        z: Latent array with ``row_num * col_num`` rows.
        row_num: Number of grid rows; one generator call is made per row.
        col_num: Number of grid columns.
        path: Output PNG path.
    """
    grid = np.zeros((row_num * col_num, 3, image_size, image_size), dtype=np.uint8)
    with chainer.no_backprop_mode():
        for i in six.moves.range(row_num):
            begin_index = i * col_num
            end_index = begin_index + col_num
            sub_image = gen2(gen1(z[begin_index:end_index], train=False), train=False).data
            # Map generator output back from [-1, 1] to [0, 255].
            sub_image = (cuda.to_cpu(sub_image) + 1) * 127.5
            grid[begin_index:end_index, ...] = sub_image.clip(0, 255).astype(np.uint8)
    # (rows*cols, C, H, W) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    grid = grid.reshape(row_num, col_num, 3, image_size, image_size)
    grid = grid.transpose((0, 3, 1, 4, 2))
    grid = grid.reshape((row_num * image_size, col_num * image_size, 3))
    Image.fromarray(grid).save(path)


def train(gen1, gen2, dis, optimizer_gen, optimizer_dis, images, epoch_num, output_path, lr_decay=10, save_epoch=1, batch_size=64, margin=20, out_image_dir=None, clip_rect=None):
    """Run the adversarial training loop for ``gen2``/``dis`` on encoded images.

    Each step decodes a batch of images and calls ``update(...)``, a helper
    presumably defined elsewhere in this file. At each epoch boundary the
    loop logs the mean losses and optionally halves the optimizers'
    ``alpha``. It then optionally writes a 10x10 sample grid and npz
    snapshots.

    Args:
        gen1: First-stage generator. It supplies ``xp``, and its output
            feeds ``gen2`` when rendering samples.
        gen2: Generator saved as ``<output_path>_<epoch>.gen.model``.
        dis: Discriminator saved as ``<output_path>_<epoch>.dis.model``.
        optimizer_gen: Generator optimizer (needs an ``alpha`` attribute).
        optimizer_dis: Discriminator optimizer (needs an ``alpha`` attribute).
        images: Dataset of encoded image bytes, iterated with ``SerialIterator``.
        epoch_num: Training stops after the epoch whose number is >= this value.
        output_path: Filename prefix for the model/optimizer snapshots.
        lr_decay: Halve both ``alpha`` values every ``lr_decay`` epochs;
            0/None disables the decay.
        save_epoch: Save samples and snapshots every ``save_epoch`` epochs;
            0/None disables saving.
        batch_size: Number of images per update.
        margin: Passed through to ``update``.
        out_image_dir: If not None, directory for ``<epoch:04d>.png`` sample grids.
        clip_rect: Optional crop box; see ``_load_image``.

    Note:
        Relies on the module-level names ``latent_size``, ``image_size`` and
        ``update``, which are not visible here.
    """
    xp = gen1.xp
    out_image_row_num = 10
    out_image_col_num = 10
    # Fixed unit-norm latents so that sample grids are comparable across epochs.
    z_out_image = xp.random.normal(0, 1, (out_image_row_num * out_image_col_num, latent_size)).astype(np.float32)
    z_out_image = z_out_image / (xp.linalg.norm(z_out_image, axis=1, keepdims=True) + 1e-12)
    x_batch = np.zeros((batch_size, 3, image_size, image_size), dtype=np.float32)
    iterator = chainer.iterators.SerialIterator(images, batch_size)
    # time.clock() was removed in Python 3.8 and measured CPU time on Unix.
    # Use a wall-clock timer, falling back to time.time on Python 2.
    clock = getattr(time, 'perf_counter', time.time)
    sum_loss_gen = 0
    sum_loss_dis = 0
    num_loss = 0
    last_clock = clock()
    for batch_images in iterator:
        batch_len = len(batch_images)
        for j, encoded in enumerate(batch_images):
            x_batch[j, ...] = _load_image(encoded, clip_rect)
        # A short batch must not carry stale rows from the previous batch.
        x = x_batch if batch_len == batch_size else x_batch[:batch_len]
        loss_gen, loss_dis = update(gen1, gen2, dis, optimizer_gen, optimizer_dis, x, margin)
        sum_loss_gen += loss_gen
        sum_loss_dis += loss_dis
        num_loss += 1
        if not iterator.is_new_epoch:
            continue

        epoch = iterator.epoch
        current_clock = clock()
        print('epoch {} done {}s elapsed'.format(epoch, current_clock - last_clock))
        print('gen loss: {}'.format(sum_loss_gen / num_loss))
        print('dis loss: {}'.format(sum_loss_dis / num_loss))
        last_clock = current_clock
        sum_loss_gen = 0
        sum_loss_dis = 0
        num_loss = 0
        if lr_decay and epoch % lr_decay == 0:
            optimizer_gen.alpha *= 0.5
            optimizer_dis.alpha *= 0.5
        if save_epoch and epoch % save_epoch == 0:
            if out_image_dir is not None:
                _save_sample_grid(gen1, gen2, z_out_image, out_image_row_num, out_image_col_num,
                                  os.path.join(out_image_dir, '{0:04d}.png'.format(epoch)))
            serializers.save_npz('{0}_{1:03d}.gen.model'.format(output_path, epoch), gen2)
            serializers.save_npz('{0}_{1:03d}.gen.state'.format(output_path, epoch), optimizer_gen)
            serializers.save_npz('{0}_{1:03d}.dis.model'.format(output_path, epoch), dis)
            serializers.save_npz('{0}_{1:03d}.dis.state'.format(output_path, epoch), optimizer_dis)
        if epoch >= epoch_num:
            break
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号