def build_summaries(gan):
"""
"""
generator_loss_summary = tf.summary.scalar(
'generator loss', gan['generator_loss'])
discriminator_loss_summary = tf.summary.scalar(
'discriminator loss', gan['discriminator_loss'])
fake_grid = tf.reshape(gan['generator_fake'], [1, 64 * 32, 32, 1])
fake_grid = tf.split(fake_grid, 8, axis=1)
fake_grid = tf.concat(fake_grid, axis=2)
fake_grid = tf.saturate_cast(fake_grid * 127.5 + 127.5, tf.uint8)
generator_fake_summary = tf.summary.image(
'generated image', fake_grid, max_outputs=18)
return {
'generator_fake_summary': generator_fake_summary,
'generator_loss_summary': generator_loss_summary,
'discriminator_loss_summary': discriminator_loss_summary,
}
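# Every snippet on this page follows the same pattern: generator outputs in the
# [-1, 1] range are rescaled with `x * 127.5 + 127.5`, and tf.saturate_cast then
# clamps any out-of-range values before casting to tf.uint8, instead of letting
# the cast overflow. A minimal sketch of just that step, assuming TensorFlow 1.x;
# the placeholder and the values fed below are illustrative only:
import tensorflow as tf

values = tf.placeholder(tf.float32, shape=[None])
as_uint8 = tf.saturate_cast(values * 127.5 + 127.5, tf.uint8)

with tf.Session() as session:
    # values slightly outside [-1, 1] clamp to 0 / 255 instead of overflowing
    print(session.run(as_uint8, feed_dict={values: [-1.2, -1.0, 0.0, 1.0, 1.2]}))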
def build_summaries(model):
"""
build image summary: [source batch, target batch, result batch]
"""
keys = ['source_images', 'target_images', 'output_images']
images = tf.concat([model[k] for k in keys], axis=2)
images = tf.reshape(images, [1, FLAGS.batch_size * 256, 768, 3])
images = tf.saturate_cast(images * 127.5 + 127.5, tf.uint8)
summary = tf.summary.image('images', images, max_outputs=4)
return {
'summary': summary,
}
def build_summaries(gan_graph):
"""
"""
generator_loss_summary = tf.summary.scalar(
'generator loss', gan_graph['generator_loss'])
discriminator_loss_summary = tf.summary.scalar(
'discriminator loss', gan_graph['discriminator_loss'])
fake_grid = tf.reshape(gan_graph['generator_fake'], [1, 64 * 32, 32, 1])
fake_grid = tf.split(fake_grid, 8, axis=1)
fake_grid = tf.concat(fake_grid, axis=2)
fake_grid = tf.saturate_cast(fake_grid * 127.5 + 127.5, tf.uint8)
generator_fake_summary = tf.summary.image(
'generated image', fake_grid, max_outputs=18)
return {
'generator_fake_summary': generator_fake_summary,
'generator_loss_summary': generator_loss_summary,
'discriminator_loss_summary': discriminator_loss_summary,
}
def build_summaries(gan_graph):
"""
"""
generator_loss_summary = tf.summary.scalar(
'generator loss', gan_graph['generator_loss'])
discriminator_loss_summary = tf.summary.scalar(
'discriminator loss', gan_graph['discriminator_loss'])
fake_grid = tf.reshape(gan_graph['generator_fake'], [1, 64 * 64, 64, 3])
fake_grid = tf.split(fake_grid, 8, axis=1)
fake_grid = tf.concat(fake_grid, axis=2)
fake_grid = tf.saturate_cast(fake_grid * 127.5 + 127.5, tf.uint8)
generator_fake_summary = tf.summary.image(
'generated image', fake_grid, max_outputs=1)
return {
'generated_png': tf.image.encode_png(fake_grid[0]),
'generator_fake_summary': generator_fake_summary,
'generator_loss_summary': generator_loss_summary,
'discriminator_loss_summary': discriminator_loss_summary,
}
def build_summaries(network):
"""
"""
summaries = {}
real = network['real']
fake = network['fake']
cute = network['ae_output_fake']
image = tf.concat([real, fake, cute], axis=0)
grid = tf.reshape(image, [1, 3 * FLAGS.image_size, FLAGS.image_size, 3])
grid = tf.split(grid, 3, axis=1)
grid = tf.concat(grid, axis=2)
grid = tf.saturate_cast(grid * 127.5 + 127.5, tf.uint8)
summaries['comparison'] = tf.summary.image('comp', grid, max_outputs=4)
return summaries
def build_summaries(model):
"""
"""
images_summary = []
generations = [
('summary_x_gx', 'xx_real', 'gx_fake'),
('summary_y_fy', 'yy_real', 'fy_fake')]
for g in generations:
images = tf.concat([model[g[1]], model[g[2]]], axis=2)
images = tf.reshape(images, [1, FLAGS.batch_size * 256, 512, 3])
images = tf.saturate_cast(images * 127.5 + 127.5, tf.uint8)
summary = tf.summary.image(g[0], images, max_outputs=4)
images_summary.append(summary)
#
summary_loss_d = tf.summary.scalar('d', model['loss_d'])
summary_loss_dx = tf.summary.scalar('dx', model['loss_dx'])
summary_loss_dy = tf.summary.scalar('dy', model['loss_dy'])
summary_d = \
tf.summary.merge([summary_loss_d, summary_loss_dx, summary_loss_dy])
summary_loss_g = tf.summary.scalar('g', model['loss_g'])
summary_loss_gx = tf.summary.scalar('gx', model['loss_gx'])
summary_loss_fy = tf.summary.scalar('fy', model['loss_fy'])
summary_g = \
tf.summary.merge([summary_loss_g, summary_loss_gx, summary_loss_fy])
return {
'images': tf.summary.merge(images_summary),
'loss_d': summary_d,
'loss_g': summary_g,
}
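# The merged summaries returned by build_summaries above are meant to be
# evaluated alongside the corresponding train step and written to TensorBoard.
# A hedged usage sketch; `model`, `d_train_op`, `g_train_op` and `num_steps`
# are placeholders that are not defined on this page:
summaries = build_summaries(model)
writer = tf.summary.FileWriter('./logs', tf.get_default_graph())

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    for step in range(num_steps):
        _, summary_d = session.run([d_train_op, summaries['loss_d']])
        _, summary_g = session.run([g_train_op, summaries['loss_g']])
        writer.add_summary(summary_d, global_step=step)
        writer.add_summary(summary_g, global_step=step)
    writer.close()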
def translate():
"""
"""
image_path_pairs = prepare_paths()
reals = tf.placeholder(shape=[None, 256, 256, 3], dtype=tf.uint8)
flow = tf.cast(reals, dtype=tf.float32) / 127.5 - 1.0
model = build_cycle_gan(flow, flow, FLAGS.mode)
fakes = tf.saturate_cast(model['fake'] * 127.5 + 127.5, tf.uint8)
# path to checkpoint
ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_dir_path)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
session.run(tf.local_variables_initializer())
tf.train.Saver().restore(session, ckpt_source_path)
for i in range(0, len(image_path_pairs), FLAGS.batch_size):
path_pairs = image_path_pairs[i:i+FLAGS.batch_size]
real_images = [scipy.misc.imread(p[0]) for p in path_pairs]
fake_images = session.run(fakes, feed_dict={reals: real_images})
for idx, path in enumerate(path_pairs):
image = np.concatenate(
[real_images[idx], fake_images[idx]], axis=1)
scipy.misc.imsave(path[1], image)
def reshape_batch_images(batch_images):
"""
"""
batch_size = FLAGS.batch_size
image_size = FLAGS.image_size
# build summary for generated fake images.
grid = \
tf.reshape(batch_images, [1, batch_size * image_size, image_size, 3])
grid = tf.split(grid, FLAGS.summary_row_size, axis=1)
grid = tf.concat(grid, axis=2)
grid = tf.saturate_cast(grid * 127.5 + 127.5, tf.uint8)
return grid
def build_image_grid(image_batch, row, col):
"""
Build an image grid from an image batch.
"""
image_size = FLAGS.image_size
grid = tf.reshape(
image_batch, [1, row * col * image_size, image_size, 3])
grid = tf.split(grid, col, axis=1)
grid = tf.concat(grid, axis=2)
grid = tf.saturate_cast(grid * 127.5 + 127.5, tf.uint8)
grid = tf.reshape(grid, [row * image_size, col * image_size, 3])
return grid
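# A hedged usage sketch for build_image_grid: 64 fake RGB images in [-1, 1]
# arranged as an 8 x 8 grid and encoded as a PNG. FLAGS.image_size is assumed
# to be 64 here, and the random batch stands in for real generator output.
fake_images = tf.random_uniform([8 * 8, 64, 64, 3], minval=-1.0, maxval=1.0)
grid = build_image_grid(fake_images, row=8, col=8)  # uint8, shape [512, 512, 3]
png_op = tf.image.encode_png(grid)

with tf.Session() as session:
    with tf.gfile.GFile('grid.png', 'wb') as f:
        f.write(session.run(png_op))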
def build_summaries(network):
"""
"""
# summary_loss = tf.summary.scalar('transfer loss', network['loss'])
images_c = network['image_content']
images_s = network['image_styled']
images_c = tf.slice(
images_c,
[0, FLAGS.padding, FLAGS.padding, 0],
[-1, 256, 256, -1])
images_s = tf.slice(
images_s,
[0, FLAGS.padding, FLAGS.padding, 0],
[-1, 256, 256, -1])
images_c = tf.reshape(images_c, [1, FLAGS.batch_size * 256, 256, 3])
images_s = tf.reshape(images_s, [1, FLAGS.batch_size * 256, 256, 3])
images_a = tf.concat([images_c, images_s], axis=2)
images_a = images_a * 127.5 + 127.5
# images_a = tf.add(images_a, VggNet.mean_color_bgr())
images_a = tf.reverse(images_a, [3])
images_a = tf.saturate_cast(images_a, tf.uint8)
summary_image = tf.summary.image('all', images_a, max_outputs=4)
# summary_plus = tf.summary.merge([summary_image, summary_loss])
return {
# 'summary_part': summary_loss,
'summary_plus': summary_image,
}
def transfer_summary(vgg, loss, content_shape):
"""
summaries of loss and result image.
"""
image = tf.add(vgg.upstream, VggNet.mean_color_bgr())
image = tf.image.resize_images(image, content_shape)
image = tf.saturate_cast(image, tf.uint8)
image = tf.reverse(image, [3])
summary_image = tf.summary.image('generated image', image, max_outputs=1)
summary_loss = tf.summary.scalar('transfer loss', loss)
return tf.summary.merge([summary_image, summary_loss])
def main(argv=None):
if not FLAGS.CONTENT_IMAGES_PATH:
print "train a fast nerual style need to set the Content images path"
return
content_images = reader.image(
FLAGS.BATCH_SIZE,
FLAGS.IMAGE_SIZE,
FLAGS.CONTENT_IMAGES_PATH,
epochs=1,
shuffle=False,
crop=False)
generated_images = model.net(content_images / 255.)
output_format = tf.saturate_cast(generated_images + reader.mean_pixel, tf.uint8)
with tf.Session() as sess:
file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
if not file:
print('Could not find trained model in {0}'.format(FLAGS.MODEL_PATH))
return
print('Using model from {}'.format(file))
saver = tf.train.Saver()
saver.restore(sess, file)
sess.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
i = 0
start_time = time.time()
try:
while not coord.should_stop():
print(i)
images_t = sess.run(output_format)
elapsed = time.time() - start_time
start_time = time.time()
print('Time for one batch: {}'.format(elapsed))
for raw_image in images_t:
i += 1
misc.imsave('out{0:04d}.png'.format(i), raw_image)
except tf.errors.OutOfRangeError:
print('Done training -- epoch limit reached')
finally:
coord.request_stop()
coord.join(threads)
def build_summaries(gan):
"""
"""
g_summaries = []
d_summaries = []
g_summaries.append(
tf.summary.scalar('generator loss', gan['generator_loss']))
d_summaries.append(
tf.summary.scalar('discriminator loss', gan['discriminator_loss']))
for vg in gan['generator_variables_gradients']:
variable_name = '{}/variable'.format(vg[0].name)
gradient_name = '{}/gradient'.format(vg[0].name)
g_summaries.append(tf.summary.histogram(variable_name, vg[0]))
g_summaries.append(tf.summary.histogram(gradient_name, vg[1]))
for vg in gan['discriminator_variables_gradients']:
variable_name = '{}/variable'.format(vg[0].name)
gradient_name = '{}/gradient'.format(vg[0].name)
d_summaries.append(tf.summary.histogram(variable_name, vg[0]))
d_summaries.append(tf.summary.histogram(gradient_name, vg[1]))
# fake image
image_width, image_depth = (64, 3) if FLAGS.use_lsun else (32, 1)
fake_grid = tf.reshape(
gan['generator_fake'],
[1, FLAGS.batch_size * image_width, image_width, image_depth])
fake_grid = tf.split(fake_grid, FLAGS.summary_col_size, axis=1)
fake_grid = tf.concat(fake_grid, axis=2)
fake_grid = tf.saturate_cast(fake_grid * 127.5 + 127.5, tf.uint8)
summary_generator_fake = tf.summary.image(
'generated image', fake_grid, max_outputs=1)
g_summaries_plus = g_summaries + [summary_generator_fake]
return {
'summary_generator': tf.summary.merge(g_summaries),
'summary_generator_plus': tf.summary.merge(g_summaries_plus),
'summary_discriminator': tf.summary.merge(d_summaries),
}
def generate():
if not FLAGS.CONTENT_IMAGE:
tf.logging.info("train a fast nerual style need to set the Content images path")
return
if not os.path.exists(FLAGS.OUTPUT_FOLDER):
os.mkdir(FLAGS.OUTPUT_FOLDER)
    # read the content image once to get its height and width
height = 0
width = 0
with open(FLAGS.CONTENT_IMAGE, 'rb') as img:
with tf.Session().as_default() as sess:
if FLAGS.CONTENT_IMAGE.lower().endswith('png'):
image = sess.run(tf.image.decode_png(img.read()))
else:
image = sess.run(tf.image.decode_jpeg(img.read()))
height = image.shape[0]
width = image.shape[1]
tf.logging.info('Image size: %dx%d' % (width, height))
with tf.Graph().as_default(), tf.Session() as sess:
        # read the content image from disk
path = FLAGS.CONTENT_IMAGE
png = path.lower().endswith('png')
img_bytes = tf.read_file(path)
        # decode as PNG or JPEG depending on the file extension
content_image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3)
content_image = tf.image.convert_image_dtype(content_image, tf.float32) * 255.0
content_image = tf.expand_dims(content_image, 0)
generated_images = transform.net(content_image - vgg.MEAN_PIXEL, training=False)
output_format = tf.saturate_cast(generated_images, tf.uint8)
        # restore the trained transform model
saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
model_path = os.path.abspath(FLAGS.MODEL_PATH)
        tf.logging.info('Using model {}'.format(model_path))
saver.restore(sess, model_path)
filename = os.path.basename(FLAGS.CONTENT_IMAGE)
(shotname, extension) = os.path.splitext(filename)
filename = shotname + '-' + os.path.basename(FLAGS.MODEL_PATH) + extension
tf.logging.info("image {}".format(filename))
images_t = sess.run(output_format)
assert len(images_t) == 1
misc.imsave(os.path.join(FLAGS.OUTPUT_FOLDER, filename), images_t[0])
def transfer():
"""
"""
if tf.gfile.IsDirectory(FLAGS.ckpt_path):
ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
elif tf.gfile.Exists(FLAGS.ckpt_path):
ckpt_source_path = FLAGS.ckpt_path
else:
assert False, 'bad checkpoint'
assert tf.gfile.Exists(FLAGS.content_path), 'bad content_path'
assert not tf.gfile.IsDirectory(FLAGS.content_path), 'bad content_path'
image_contents = build_contents_reader()
network = build_style_transfer_network(image_contents, training=False)
#
shape = tf.shape(network['image_styled'])
    new_h = shape[1] - 2 * FLAGS.padding
    new_w = shape[2] - 2 * FLAGS.padding
    image_styled = tf.slice(
        network['image_styled'],
        [0, FLAGS.padding, FLAGS.padding, 0],
        [-1, new_h, new_w, -1])
image_styled = tf.squeeze(image_styled, [0])
image_styled = image_styled * 127.5 + 127.5
image_styled = tf.reverse(image_styled, [2])
image_styled = tf.saturate_cast(image_styled, tf.uint8)
image_styled = tf.image.encode_jpeg(image_styled)
image_styled_writer = tf.write_file(FLAGS.styled_path, image_styled)
with tf.Session() as session:
tf.train.Saver().restore(session, ckpt_source_path)
# make dataset reader work
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
session.run(image_styled_writer)
coord.request_stop()
coord.join(threads)