def main(args):
    """Dump a few input batches as image summaries for visual inspection.

    Builds a ``Dataset`` input pipeline, converts the ``'x_t_1'`` batch back
    to ``uint8`` (``x_t_1 * 255.0 + mean``), and writes it as an image summary
    for 10 consecutive steps to ``<args.log>/check``.

    Args:
        args: Parsed options. This function reads ``args.train``,
            ``args.mean``, ``args.batch_size`` and ``args.log``; the whole
            object is also forwarded to ``get_config``.
    """
    with tf.Graph().as_default() as graph:
        # Create dataset
        logging.info('Create data flow from %s', args.train)
        train_data = Dataset(directory=args.train, mean_path=args.mean, batch_size=args.batch_size, num_threads=2, capacity=10000)
        # Create initializer
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        # Config session
        config = get_config(args)
        # Setup summary
        check_summary_writer = tf.summary.FileWriter(os.path.join(args.log, 'check'), graph)
        try:
            # Invoke the dataset once and reuse the result so that 'x_t_1' and
            # 'mean' are guaranteed to come from the same call.
            # NOTE(review): assumes Dataset.__call__ returns a mapping of
            # tensors and that one call is what was intended -- confirm.
            batch = train_data()
            check_op = tf.cast(batch['x_t_1'] * 255.0 + batch['mean'], tf.uint8)
            tf.summary.image('x_t_1_batch_restore', check_op, collections=['check'])
            check_summary_op = tf.summary.merge_all('check')
            # Start session
            with tf.Session(config=config) as sess:
                coord = tf.train.Coordinator()
                sess.run(init)
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    for step in range(10):
                        # Only the serialized summary is needed; the image
                        # tensor itself is evaluated as its dependency.
                        summary = sess.run(check_summary_op)
                        check_summary_writer.add_summary(summary, step)
                finally:
                    # Always stop the queue-runner threads, even if a step
                    # raised, so they do not outlive the session.
                    coord.request_stop()
                    coord.join(threads)
        finally:
            # Closing flushes pending events to disk.
            check_summary_writer.close()
# check.py -- file source code
# Language: python
# Reads: 25
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents