def _parse_clip_rect(text):
    """Parse an "x,y,w,h" string into an (x, y, x + w, y + h) tuple of ints.

    The third and fourth values are treated as extents and added to the
    first and second values to produce the opposite corner.

    Raises:
        ValueError: if ``text`` does not contain exactly four
            comma-separated integers.
    """
    # Build a list rather than using ``map`` so the values are indexable and
    # countable on both Python 2 and Python 3.
    values = [int(part) for part in text.split(',')]
    if len(values) != 4:
        raise ValueError(
            'expected 4 comma-separated integers, got {}'.format(len(values)))
    x, y, w, h = values
    return (x, y, x + w, y + h)


def _prepare_out_image_dir(path):
    """Make sure ``path`` is a directory, creating it if it does not exist.

    Exits the process with a non-zero status (the message goes to stderr) in
    two cases: ``path`` exists but is not a directory, or it cannot be
    created.
    """
    import sys

    if os.path.isdir(path):
        return
    if os.path.exists(path):
        sys.exit('file path {} exists but is not directory'.format(path))
    try:
        os.mkdir(path)
    except OSError as err:
        # Catch only OS errors (not a bare except) so Ctrl-C still works,
        # and report why the directory could not be created.
        sys.exit('cannot make directory {}: {}'.format(path, err))


def main():
    """Command-line entry point: build the GAN and hand it to ``train``.

    Steps:
      1. Parse the command-line arguments with ``parse_args()``.
      2. Parse the optional clip rectangle.
      3. Create ``net.Generator1`` / ``net.Discriminator1``.
      4. Move both networks to the selected CUDA device when
         ``args.gpu >= 0``.
      5. Attach an Adam optimizer (alpha=0.001) to each network.
      6. If ``args.input`` is given, resume from
         ``<input>.{gen,dis}.{model,state}``.
      7. Prepare the optional output image directory.
      8. Load the pickled dataset from ``args.dataset``.
      9. Call ``train``.

    Invalid arguments terminate the process with a non-zero exit status.
    """
    import sys

    args = parse_args()

    # Validate the clip rectangle before doing any expensive setup.
    clip_rect = None
    if args.clip_rect:
        try:
            clip_rect = _parse_clip_rect(args.clip_rect)
        except ValueError as err:
            sys.exit('invalid clip_rect {}: {}'.format(args.clip_rect, err))

    gen = net.Generator1()
    dis = net.Discriminator1()

    if args.gpu >= 0:
        device_id = args.gpu
        cuda.get_device(device_id).use()
        gen.to_gpu(device_id)
        dis.to_gpu(device_id)

    # Optimizers are attached after the models are moved to the device.
    optimizer_gen = optimizers.Adam(alpha=0.001)
    optimizer_gen.setup(gen)
    optimizer_dis = optimizers.Adam(alpha=0.001)
    optimizer_dis.setup(dis)

    # Resume training: args.input is a path prefix shared by the four files.
    if args.input is not None:
        serializers.load_npz(args.input + '.gen.model', gen)
        serializers.load_npz(args.input + '.gen.state', optimizer_gen)
        serializers.load_npz(args.input + '.dis.model', dis)
        serializers.load_npz(args.input + '.dis.state', optimizer_dis)

    if args.out_image_dir is not None:
        _prepare_out_image_dir(args.out_image_dir)

    # NOTE(security): pickle.load can execute arbitrary code while
    # unpickling; only pass dataset files that come from a trusted source.
    with open(args.dataset, 'rb') as f:
        images = pickle.load(f)

    train(gen, dis, optimizer_gen, optimizer_dis, images, args.epoch,
          batch_size=args.batch_size,
          margin=args.margin,
          save_epoch=args.save_epoch,
          lr_decay=args.lr_decay,
          output_path=args.output,
          out_image_dir=args.out_image_dir,
          clip_rect=clip_rect)
评论列表
文章目录