def run(args):
    """Build the BicycleGAN graph, optionally train it, then run the test split.

    Args:
        args: Parsed command-line options. This function reads ``gpu``,
            ``task``, ``image_size``, ``load_model`` and ``train`` directly;
            the whole object is also handed to ``BicycleGAN``.

    Side effects:
        Sets ``CUDA_DEVICE_ORDER`` / ``CUDA_VISIBLE_DEVICES``; creates
        ``./logs`` and ``./results`` (plus ``results/<model_name>``); the
        Supervisor writes events and checkpoints under
        ``./logs/<model_name>``.
    """
    # Enumerate CUDA devices by PCI bus ID so the index given in args.gpu
    # refers to a stable physical device.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    logger.info('Read data:')
    train_A, train_B, test_A, test_B = get_data(args.task, args.image_size)

    logger.info('Build graph:')
    model = BicycleGAN(args)
    # Snapshot the graph's variables right after construction; this list
    # drives the Supervisor's init op, saver and readiness check below.
    variables_to_save = tf.global_variables()
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)

    logger.info('Trainable vars:')
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    for v in var_list:
        logger.info(' %s %s', v.name, v.get_shape())

    # Truthiness (rather than ``!= ''``) makes None fall back to a fresh
    # timestamped name instead of reaching os.path.join below as None.
    if args.load_model:
        model_name = args.load_model
    else:
        model_name = '{}_{}'.format(
            args.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

    logdir = './logs'
    makedirs(logdir)
    logdir = os.path.join(logdir, model_name)
    logger.info('Events directory: %s', logdir)
    summary_writer = tf.summary.FileWriter(logdir)
    makedirs('./results')

    def init_fn(sess):
        # Passed to the Supervisor as init_fn; per the tf.train.Supervisor
        # docs it is called when the session is initialized from scratch
        # rather than restored from a checkpoint in logdir.
        logger.info('Initializing all parameters.')
        sess.run(init_all_op)

    sv = tf.train.Supervisor(is_chief=True,
                             logdir=logdir,
                             saver=saver,
                             summary_op=None,
                             init_op=init_op,
                             init_fn=init_fn,
                             summary_writer=summary_writer,
                             ready_op=tf.report_uninitialized_variables(variables_to_save),
                             global_step=model.global_step,
                             save_model_secs=300,
                             save_summaries_secs=30)

    if args.train:
        logger.info("Starting training session.")
        with sv.managed_session() as sess:
            model.train(sess, summary_writer, train_A, train_B)
            # The testing session below is a brand-new tf.Session that gets
            # its weights from the latest checkpoint in logdir. The
            # Supervisor only checkpoints on its save_model_secs timer and
            # does not save on stop(), so persist the final weights now
            # (same saver / path / step as the Supervisor's own saves).
            saver.save(sess, sv.save_path, global_step=model.global_step)

    logger.info("Starting testing session.")
    # NOTE(review): ``sv`` is reused after the training session exited;
    # Supervisor.stop() closes the summary writer by default, so any events
    # emitted during testing may be dropped — confirm if that matters.
    with sv.managed_session() as sess:
        base_dir = os.path.join('results', model_name)
        makedirs(base_dir)
        model.test(sess, test_A, test_B, base_dir)
评论列表
文章目录