def main(_):
    """Build the UNet model from command-line options and train it.

    The single positional argument is ignored. It is presumably the argv
    list supplied by ``tf.app.run``; confirm against the caller.

    All settings are read from a module-level ``args`` namespace that is not
    defined in this block. It is assumed to be argparse output (TODO confirm).

    ``args.fine_tune``, when non-empty, is parsed as a comma-separated list
    of integer ids. Blank entries (e.g. from a trailing or doubled comma) are
    ignored. If no ids remain, training is called with ``fine_tune=None``,
    the same as when the option is not given.
    """
    config = tf.ConfigProto()
    # Allocate GPU memory on demand rather than reserving it all up front.
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = UNet(
            args.experiment_dir,
            batch_size=args.batch_size,
            experiment_id=args.experiment_id,
            input_width=args.image_size,
            output_width=args.image_size,
            embedding_num=args.embedding_num,
            embedding_dim=args.embedding_dim,
            L1_penalty=args.L1_penalty,
            Lconst_penalty=args.Lconst_penalty,
            Ltv_penalty=args.Ltv_penalty,
            Lcategory_penalty=args.Lcategory_penalty,
        )
        model.register_session(sess)

        # no_target_source is passed only when flipping labels. Otherwise
        # build_model's own default applies (that default is not visible here).
        build_kwargs = {"no_target_source": True} if args.flip_labels else {}
        model.build_model(is_training=True, inst_norm=args.inst_norm, **build_kwargs)

        fine_tune_list = None
        if args.fine_tune:
            # Skip blank tokens so inputs like "1,2," or "1,,2" do not fail on
            # int(""). int() already tolerates surrounding whitespace.
            ids = {int(token) for token in args.fine_tune.split(",") if token.strip()}
            fine_tune_list = ids or None

        model.train(
            lr=args.lr,
            epoch=args.epoch,
            resume=args.resume,
            schedule=args.schedule,
            freeze_encoder=args.freeze_encoder,
            fine_tune=fine_tune_list,
            sample_steps=args.sample_steps,
            checkpoint_steps=args.checkpoint_steps,
            flip_labels=args.flip_labels,
            no_val=args.no_val,
        )
# Scraped web-page residue, not code: "评论列表" = "comment list",
# "文章目录" = "table of contents".