train.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainer-stack-gan 作者: dsanno 项目源码 文件源码
def main():
    """Command-line entry point: build the networks and start training.

    Steps, in order:

    1. Parse options with ``parse_args()``.
    2. Create ``net.Generator1`` / ``net.Discriminator1`` and, when
       ``args.gpu >= 0``, move both to that GPU.
    3. Attach an Adam optimizer (``alpha=0.001``) to each network.
    4. If ``args.input`` is given, restore models and optimizer states from
       ``<input>.gen.model``, ``<input>.gen.state``, ``<input>.dis.model``
       and ``<input>.dis.state``.
    5. Make sure ``args.out_image_dir`` (if given) is an existing directory,
       creating it when missing.
    6. Unpickle the dataset at ``args.dataset`` and hand everything to
       ``train``.

    Raises:
        SystemExit: with status 1 when ``args.out_image_dir`` cannot be
            created, or exists but is not a directory.
    """
    args = parse_args()
    gen = net.Generator1()
    dis = net.Discriminator1()

    clip_rect = None
    if args.clip_rect:
        # A list comprehension (not map()) keeps the values indexable on
        # Python 3, where map() returns an iterator.
        values = [int(v) for v in args.clip_rect.split(',')]
        # The option appears to be "x,y,width,height"; it is converted to
        # (x, y, x + width, y + height) for train(). Extra values are ignored.
        left, top, width, height = values[:4]
        clip_rect = (left, top, left + width, top + height)

    if args.gpu >= 0:
        device_id = args.gpu
        cuda.get_device(device_id).use()
        gen.to_gpu(device_id)
        dis.to_gpu(device_id)

    optimizer_gen = optimizers.Adam(alpha=0.001)
    optimizer_gen.setup(gen)
    optimizer_dis = optimizers.Adam(alpha=0.001)
    optimizer_dis.setup(dis)

    if args.input is not None:
        # Resume: optimizers must already be set up before their state loads.
        serializers.load_npz(args.input + '.gen.model', gen)
        serializers.load_npz(args.input + '.gen.state', optimizer_gen)
        serializers.load_npz(args.input + '.dis.model', dis)
        serializers.load_npz(args.input + '.dis.state', optimizer_dis)

    if args.out_image_dir is not None:
        if not os.path.exists(args.out_image_dir):
            try:
                os.mkdir(args.out_image_dir)
            except OSError:
                # Only filesystem failures are handled here; Ctrl-C and
                # other unrelated exceptions propagate normally.
                print('cannot make directory {}'.format(args.out_image_dir))
                # Non-zero status so calling scripts can detect the failure.
                raise SystemExit(1)
        elif not os.path.isdir(args.out_image_dir):
            print('file path {} exists but is not directory'.format(args.out_image_dir))
            raise SystemExit(1)

    # NOTE(security): pickle.load can execute arbitrary code; only point
    # args.dataset at files from a trusted source.
    with open(args.dataset, 'rb') as f:
        images = pickle.load(f)

    train(
        gen,
        dis,
        optimizer_gen,
        optimizer_dis,
        images,
        args.epoch,
        batch_size=args.batch_size,
        margin=args.margin,
        save_epoch=args.save_epoch,
        lr_decay=args.lr_decay,
        output_path=args.output,
        out_image_dir=args.out_image_dir,
        clip_rect=clip_rect,
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号