generate.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:chainer-stack-gan 作者: dsanno 项目源码 文件源码
def main():
    """Render a 10-column image sheet with the stage-1 generator.

    Reads the command-line options via ``parse_args()``, restores the
    generator weights from ``args.model_path`` and, when ``args.gpu`` is
    non-negative, moves the model to that GPU.

    Latent vectors come from one of two places:

    * Interpolation mode -- when both vector files and both (non-negative)
      vector indices are given, one row is taken from each file and ten
      evenly weighted blends between them are generated.
    * Random mode -- otherwise 100 vectors are drawn from a standard normal
      distribution (seeded with 1) and written to ``<output>.npy``.

    In both modes every latent row is scaled to unit L2 norm before being
    fed to the generator. The generated images are tiled ten per row and
    saved as ``<output>.jpg``.
    """
    args = parse_args()

    # Build the generator and restore its trained parameters.
    gen1 = net.Generator1()
    chainer.serializers.load_npz(args.model_path, gen1)

    use_gpu = args.gpu >= 0
    device_id = args.gpu if use_gpu else None
    if use_gpu:
        cuda.get_device(device_id).use()
        gen1.to_gpu(device_id)

    # Fixed seed so the random mode is reproducible between runs.
    np.random.seed(1)
    out_vector_path = None

    interpolate = (
        args.vector_file1 and args.vector_index1 >= 0
        and args.vector_file2 and args.vector_index2 >= 0
    )
    if interpolate:
        with open(args.vector_file1, 'rb') as f:
            vectors = np.load(f)
            start = vectors[args.vector_index1]
        with open(args.vector_file2, 'rb') as f:
            vectors = np.load(f)
            end = vectors[args.vector_index2]
        # Column of ten blend weights running from 0 to 1 inclusive.
        steps = np.arange(10).astype(np.float32)
        weight = steps.reshape((-1, 1)) / 9
        latents = (1 - weight) * start + weight * end
    else:
        latents = np.random.normal(0, 1, (100, latent_size)).astype(np.float32)
        out_vector_path = '{}.npy'.format(args.output)

    # Scale each row to unit length; the epsilon guards against a zero norm.
    row_norms = np.linalg.norm(latents, axis=1, keepdims=True)
    latents = latents / (row_norms + 1e-12)

    with chainer.no_backprop_mode():
        if device_id is None:
            gen_input = latents
        else:
            gen_input = cuda.to_gpu(latents, device_id)
        generated = gen1(gen_input, train=False)

    images = cuda.to_cpu(generated.data)
    _, channels, height, width = images.shape

    # (rows, 10, ch, h, w) -> (rows, h, 10, w, ch) -> (rows * h, 10 * w, ch):
    # lays the images out ten per row in a single channel-last array.
    sheet = images.reshape((-1, 10, channels, height, width))
    sheet = sheet.transpose((0, 3, 1, 4, 2))
    sheet = sheet.reshape((-1, 10 * width, channels))

    # Map values from [-1, 1] onto [0, 255], clipping anything outside.
    pixels = ((sheet + 1) * 127.5).clip(0, 255).astype(np.uint8)
    Image.fromarray(pixels).save('{}.jpg'.format(args.output))

    if out_vector_path:
        with open(out_vector_path, 'wb') as f:
            np.save(f, latents)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号