def main(argv=None):
    """Stylize every content image with the most recent trained checkpoint.

    Builds an input pipeline over ``FLAGS.CONTENT_IMAGES_PATH`` (one epoch,
    no shuffling, no cropping, batches of ``FLAGS.BATCH_SIZE`` at
    ``FLAGS.IMAGE_SIZE``). It runs the batches through ``model.net`` using
    the newest checkpoint found in ``FLAGS.MODEL_PATH``. Each generated
    image is written to ``outNNNN.png`` in the current working directory,
    numbered from 0001.

    Args:
        argv: Unused. Presumably accepted so this function can be handed to
            ``tf.app.run``; confirm against the caller.

    Returns:
        None. Returns early, after printing a message, when the content
        path is unset or when no checkpoint exists in ``FLAGS.MODEL_PATH``.
    """
    if not FLAGS.CONTENT_IMAGES_PATH:
        print("train a fast neural style need to set the Content images path")
        return

    # A single deterministic pass over the inputs, so each content image
    # yields exactly one output file.
    content_images = reader.image(
        FLAGS.BATCH_SIZE,
        FLAGS.IMAGE_SIZE,
        FLAGS.CONTENT_IMAGES_PATH,
        epochs=1,
        shuffle=False,
        crop=False)
    # NOTE(review): the network input is scaled by 1/255, while the output is
    # offset by reader.mean_pixel. Confirm this matches the preprocessing
    # used at training time.
    generated_images = model.net(content_images / 255.)
    # saturate_cast clamps to the uint8 range before casting, so
    # out-of-range pixels saturate instead of wrapping around.
    output_format = tf.saturate_cast(generated_images + reader.mean_pixel, tf.uint8)

    with tf.Session() as sess:
        checkpoint_file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if not checkpoint_file:
            print('Could not find trained model in {0}'.format(FLAGS.MODEL_PATH))
            return
        print('Using model from {}'.format(checkpoint_file))
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_file)
        # Local variables are initialized separately from the restore.
        # Presumably the epochs=1 limit of the input pipeline is tracked in
        # one (TODO confirm in reader.image).
        sess.run(tf.local_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        image_count = 0
        start_time = time.time()
        try:
            while not coord.should_stop():
                print(image_count)
                images_t = sess.run(output_format)
                elapsed = time.time() - start_time
                start_time = time.time()
                print('Time for one batch: {}'.format(elapsed))
                for raw_image in images_t:
                    image_count += 1
                    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2.
                    # This call needs an older SciPy, or a replacement writer.
                    misc.imsave('out{0:04d}.png'.format(image_count), raw_image)
        except tf.errors.OutOfRangeError:
            # Raised by sess.run once the input pipeline has exhausted its epoch.
            print('Done training -- epoch limit reached')
        finally:
            # Always stop and join the queue-runner threads, even on error.
            coord.request_stop()
            coord.join(threads)
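# Web-page navigation residue from the article this script was copied from:
# 评论列表 ("Comments"), 文章目录 ("Table of contents")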