def main():
    """Preview refined VAE outputs for randomly perturbed embeddings.

    Reads image batches from the files matching ``data/*-2.tfrecord.gz``,
    runs them through the graph imported from ``exported/vae-refine.pb`` to
    obtain their embeddings, adds standard-normal noise to the embeddings and
    feeds the salted embeddings back into the graph to fetch the refined
    images. The first three inputs and results of each batch are rendered
    with ``example_gallery`` and shown in an OpenCV window.

    The loop ends when the input pipeline raises ``OutOfRangeError`` or when
    the user presses Esc (key code 27) in the preview window. The window is
    destroyed on every exit path, including failures while building the graph
    or while joining the queue-runner threads.
    """
    window = 'preview'
    cv2.namedWindow(window)
    try:
        tfrecord_file_names = glob(path.join('data', '*-2.tfrecord.gz'))
        max_reads = 200
        batch_size = 50

        with tf.Graph().as_default() as graph:
            # Only the image batch is used; the second return value is ignored.
            image_batch, _ = import_images(tfrecord_file_names,
                                           max_reads=max_reads,
                                           batch_size=batch_size)
            import_graph('exported/vae-refine.pb',
                         input_map={'image_batch': image_batch},
                         prefix='process')

            phase_train = graph.get_tensor_by_name('process/mogrify/vae/phase_train:0')
            embedding = graph.get_tensor_by_name('process/mogrify/vae/variational/add:0')

            # Not fetched below; looked up and shaped as in the original code.
            # Shape is presumably (batch, height, width, channels) - TODO confirm.
            reconstructed = graph.get_tensor_by_name('process/mogrify/clip:0')
            reconstructed.set_shape((None, 180, 320, 3))

            refined = graph.get_tensor_by_name('process/refine/y:0')
            refined.set_shape((None, 180, 320, 3))

            coord = tf.train.Coordinator()
            # The session is opened while `graph` is still the default graph so
            # that the initializer ops created below belong to the same graph
            # the session executes.
            with tf.Session(graph=graph) as sess:
                init = tf.group(tf.local_variables_initializer(),
                                tf.global_variables_initializer())
                sess.run(init)
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    print('Evaluating ...')
                    while not coord.should_stop():
                        # fetching the embeddings given the inputs ...
                        reference, coeffs = sess.run([image_batch, embedding],
                                                     feed_dict={phase_train: False})

                        # ... then salting the embeddings ...
                        # Noise follows the embedding's full shape instead of
                        # assuming it is exactly two-dimensional.
                        coeffs += np.random.randn(*coeffs.shape)

                        # ... then fetching the images given the new embeddings.
                        results = sess.run(refined,
                                           feed_dict={phase_train: False,
                                                      embedding: coeffs})

                        assert reference.shape == results.shape
                        reference = reference[:3]
                        results = results[:3]

                        canvas = example_gallery(reference, results)
                        cv2.imshow(window, canvas)

                        # Show each gallery for up to one second; Esc cancels.
                        if (cv2.waitKey(1000) & 0xff) == 27:
                            print('User requested cancellation.')
                            coord.request_stop()
                            break
                except tf.errors.OutOfRangeError:
                    print('Read all examples.')
                finally:
                    # join() may re-raise an exception reported by a worker
                    # thread; the outer finally still tears down the window.
                    coord.request_stop()
                    coord.join(threads)
    finally:
        cv2.destroyWindow(window)
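# NOTE: web-page navigation labels ("Comment list", "Table of contents") were
# captured together with the code here; they are not part of the program.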