def train():
    """Build the CycleGAN training graph and run training until told to stop.

    Steps:
      * Read image batches for both domains from ``FLAGS.x_images_dir_path``
        and ``FLAGS.y_images_dir_path`` (batch size ``FLAGS.batch_size``).
      * Build the model with ``build_cycle_gan`` and its summaries with
        ``build_summaries``.
      * Initialize all variables, then restore the newest checkpoint in
        ``FLAGS.ckpt_dir_path`` if one exists.
      * Call ``train_one_step`` repeatedly until it returns a falsy value.

    Summaries and session logs are written to ``FLAGS.logs_dir_path``.

    The queue-runner threads are always stopped and joined, and the summary
    writer is always closed, even if a training step raises.

    Returns:
        None.

    Raises:
        Any exception raised by ``train_one_step``. ``Coordinator.join`` also
        re-raises an exception that a queue-runner thread reported to the
        coordinator.
    """
    # None when the directory holds no checkpoint yet (fresh run).
    ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)

    # Input batches for the two image domains. These readers need queue
    # runners to be started (see start_queue_runners below).
    xx_real = build_image_batch_reader(
        FLAGS.x_images_dir_path, FLAGS.batch_size)
    yy_real = build_image_batch_reader(
        FLAGS.y_images_dir_path, FLAGS.batch_size)

    # Mutable state shared across every train_one_step call (presumably a
    # pool of previously generated images; confirm in train_one_step).
    image_pool = {}

    model = build_cycle_gan(xx_real, yy_real, '')
    summaries = build_summaries(model)
    reporter = tf.summary.FileWriter(FLAGS.logs_dir_path)

    try:
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            session.run(tf.local_variables_initializer())

            # Resume training state when a checkpoint exists; this overwrites
            # the freshly initialized variables.
            if ckpt_source_path is not None:
                tf.train.Saver().restore(session, ckpt_source_path)

            step = session.run(model['step'])

            # Mark a (re)start at the current step. On a START event,
            # TensorBoard discards previously logged events at or after this
            # step, so a resumed run does not leave overlapping old data.
            reporter.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START),
                global_step=step)

            # Start the queue-runner threads that feed the input readers.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=session, coord=coord)

            try:
                # train_one_step returns a truthy value to keep training.
                while train_one_step(model, summaries, image_pool, reporter):
                    pass
            finally:
                # Stop and join the feeder threads before the session closes,
                # also when a training step raised.
                coord.request_stop()
                coord.join(threads)
    finally:
        # Flush any pending events to disk and release the event file.
        reporter.close()
# [scraped web-page residue] "评论列表" (Comment list)
# [scraped web-page residue] "文章目录" (Table of contents)